diff --git a/.github/workflows/bot.yml b/.github/workflows/bot.yml index 7b4376b04ce4b..1aaaba8eb6aff 100644 --- a/.github/workflows/bot.yml +++ b/.github/workflows/bot.yml @@ -80,6 +80,10 @@ jobs: sparkProfile: "spark3.4" sparkModules: "hudi-spark-datasource/hudi-spark3.4.x" + - scalaProfile: "scala-2.12" + sparkProfile: "spark3.5" + sparkModules: "hudi-spark-datasource/hudi-spark3.5.x" + steps: - uses: actions/checkout@v3 - name: Set up JDK 8 @@ -163,6 +167,9 @@ jobs: - scalaProfile: "scala-2.12" sparkProfile: "spark3.4" sparkModules: "hudi-spark-datasource/hudi-spark3.4.x" + - scalaProfile: "scala-2.12" + sparkProfile: "spark3.5" + sparkModules: "hudi-spark-datasource/hudi-spark3.5.x" steps: - uses: actions/checkout@v3 @@ -255,6 +262,9 @@ jobs: strategy: matrix: include: + - flinkProfile: 'flink1.18' + sparkProfile: 'spark3.5' + sparkRuntime: 'spark3.5.0' - flinkProfile: 'flink1.18' sparkProfile: 'spark3.4' sparkRuntime: 'spark3.4.0' @@ -284,6 +294,9 @@ jobs: strategy: matrix: include: + - flinkProfile: 'flink1.18' + sparkProfile: 'spark3.5' + sparkRuntime: 'spark3.5.0' - flinkProfile: 'flink1.18' sparkProfile: 'spark3.4' sparkRuntime: 'spark3.4.0' diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala index a0fe879b3dbea..527864fcf244a 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala @@ -51,6 +51,7 @@ private[hudi] trait SparkVersionsSupport { def isSpark3_2: Boolean = getSparkVersion.startsWith("3.2") def isSpark3_3: Boolean = getSparkVersion.startsWith("3.3") def isSpark3_4: Boolean = getSparkVersion.startsWith("3.4") + def isSpark3_5: Boolean = getSparkVersion.startsWith("3.5") def gteqSpark3_0: Boolean = getSparkVersion >= "3.0" def gteqSpark3_1: Boolean = getSparkVersion >= "3.1" @@ -61,6 +62,7 @@ private[hudi] trait SparkVersionsSupport { def gteqSpark3_3: Boolean = getSparkVersion >= "3.3" def gteqSpark3_3_2: Boolean = getSparkVersion >= "3.3.2" def gteqSpark3_4: Boolean = getSparkVersion >= "3.4" + def gteqSpark3_5: Boolean = getSparkVersion >= "3.5" } object HoodieSparkUtils extends SparkAdapterSupport with SparkVersionsSupport with Logging { diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala index 7e035a95ef5fb..09229d74b2059 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala @@ -33,7 +33,9 @@ trait SparkAdapterSupport { object SparkAdapterSupport { lazy val sparkAdapter: SparkAdapter = { - val adapterClass = if (HoodieSparkUtils.isSpark3_4) { + val adapterClass = if (HoodieSparkUtils.isSpark3_5) { + "org.apache.spark.sql.adapter.Spark3_5Adapter" + } else if (HoodieSparkUtils.isSpark3_4) { "org.apache.spark.sql.adapter.Spark3_4Adapter" } else if (HoodieSparkUtils.isSpark3_3) { "org.apache.spark.sql.adapter.Spark3_3Adapter" diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/DataFrameUtil.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/DataFrameUtil.scala index 290b118bd8978..11ccc59388ebb 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/DataFrameUtil.scala 
+++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/DataFrameUtil.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql +import org.apache.hudi.SparkAdapterSupport import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.LogicalRDD @@ -31,7 +32,8 @@ object DataFrameUtil { */ def createFromInternalRows(sparkSession: SparkSession, schema: StructType, rdd: RDD[InternalRow]): DataFrame = { - val logicalPlan = LogicalRDD(schema.toAttributes, rdd)(sparkSession) + val logicalPlan = LogicalRDD( + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes(schema), rdd)(sparkSession) Dataset.ofRows(sparkSession, logicalPlan) } -} \ No newline at end of file +} diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala index a83afd514f1c3..df55a19db441c 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala @@ -18,20 +18,22 @@ package org.apache.spark.sql import org.apache.hudi.SparkAdapterSupport -import org.apache.hudi.SparkAdapterSupport.sparkAdapter -import org.apache.hudi.common.util.ValidationUtils.checkState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeEq, AttributeReference, Cast, Expression, Like, Literal, MutableProjection, SubqueryExpression, UnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateStruct, Expression, GetStructField, Like, Literal, Projection, SubqueryExpression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeEq, AttributeReference, AttributeSet, Cast, Expression, Like, Literal, SubqueryExpression, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} trait HoodieCatalystExpressionUtils { + /** + * SPARK-44531 Encoder inference moved elsewhere in Spark 3.5.0 + * Mainly used for unit tests + */ + def getEncoder(schema: StructType): ExpressionEncoder[Row] + /** * Returns a filter that its reference is a subset of `outputSet` and it contains the maximum * constraints from `condition`. 
This is used for predicate push-down @@ -269,7 +271,7 @@ object HoodieCatalystExpressionUtils extends SparkAdapterSupport { } private def generateUnsafeProjectionInternal(from: StructType, to: StructType): UnsafeProjection = { - val attrs = from.toAttributes + val attrs = sparkAdapter.getSchemaUtils.toAttributes(from) val attrsMap = attrs.map(attr => (attr.name, attr)).toMap val targetExprs = to.fields.map(f => attrsMap(f.name)) diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieSchemaUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieSchemaUtils.scala index 2ee323ec37008..2ee489ada4d5e 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieSchemaUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieSchemaUtils.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType + /** * Utils on schema, which have different implementation across Spark versions. */ @@ -34,4 +37,10 @@ trait HoodieSchemaUtils { def checkColumnNameDuplication(columnNames: Seq[String], colType: String, caseSensitiveAnalysis: Boolean): Unit + + /** + * SPARK-44353 StructType#toAttributes was removed in Spark 3.5.0 + * Use DataTypeUtils#toAttributes for Spark 3.5+ + */ + def toAttributes(struct: StructType): Seq[Attribute] } diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieUnsafeUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieUnsafeUtils.scala index ee22f714c9c90..138815bc9c848 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieUnsafeUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieUnsafeUtils.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql -import org.apache.hudi.HoodieUnsafeRDD +import org.apache.hudi.{HoodieUnsafeRDD, SparkAdapterSupport} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -68,14 +68,15 @@ object HoodieUnsafeUtils { * Creates [[DataFrame]] from the in-memory [[Seq]] of [[Row]]s with provided [[schema]] * * NOTE: [[DataFrame]] is based on [[LocalRelation]], entailing that most computations with it - * will be executed by Spark locally + * will be executed by Spark locally * - * @param spark spark's session - * @param rows collection of rows to base [[DataFrame]] on + * @param spark spark's session + * @param rows collection of rows to base [[DataFrame]] on * @param schema target [[DataFrame]]'s schema */ def createDataFrameFromRows(spark: SparkSession, rows: Seq[Row], schema: StructType): DataFrame = - Dataset.ofRows(spark, LocalRelation.fromExternalRows(schema.toAttributes, rows)) + Dataset.ofRows(spark, LocalRelation.fromExternalRows( + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes(schema), rows)) /** * Creates [[DataFrame]] from the in-memory [[Seq]] of [[InternalRow]]s with provided [[schema]] @@ -88,7 +89,7 @@ object HoodieUnsafeUtils { * @param schema target [[DataFrame]]'s schema */ def createDataFrameFromInternalRows(spark: SparkSession, rows: Seq[InternalRow], schema: StructType): DataFrame = - Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) + Dataset.ofRows(spark, LocalRelation(SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes(schema), rows)) /** diff --git 
a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSparkPartitionedFileUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSparkPartitionedFileUtils.scala index 0e3b3f261d824..53d95f09394be 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSparkPartitionedFileUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSparkPartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] to adapt to type changes. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] to adapt to type changes. * Before Spark 3.4.0, * ``` * case class PartitionedFile( @@ -65,13 +65,23 @@ trait HoodieSparkPartitionedFileUtils extends Serializable { * Creates a new [[PartitionedFile]] instance. * * @param partitionValues value of partition columns to be prepended to each row. - * @param filePath URI of the file to read. - * @param start the beginning offset (in bytes) of the block. - * @param length number of bytes to read. + * @param filePath URI of the file to read. + * @param start the beginning offset (in bytes) of the block. + * @param length number of bytes to read. * @return a new [[PartitionedFile]] instance. */ def createPartitionedFile(partitionValues: InternalRow, filePath: Path, start: Long, length: Long): PartitionedFile + + /** + * SPARK-43039 FileIndex#PartitionDirectory refactored in Spark 3.5.0 + */ + def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] + + /** + * SPARK-43039 FileIndex#PartitionDirectory refactored in Spark 3.5.0 + */ + def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory } diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala index 1c6111afe47f3..5691dd5c3805b 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala @@ -19,14 +19,15 @@ package org.apache.spark.sql.hudi import org.apache.avro.Schema -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.spark.sql._ import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSchemaConverters, HoodieAvroSerializer} import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.CatalogTable -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InterpretedPredicate} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} diff --git a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java 
b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java index 2993084cfc607..28b054352445f 100644 --- a/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java +++ b/hudi-common/src/test/java/org/apache/hudi/avro/TestHoodieAvroUtils.java @@ -295,7 +295,7 @@ public void testRemoveFields() { // partitioned table test. String schemaStr = "{\"type\": \"record\",\"name\": \"testrec\",\"fields\": [ " + "{\"name\": \"timestamp\",\"type\": \"double\"},{\"name\": \"_row_key\", \"type\": \"string\"}," - + "{\"name\": \"non_pii_col\", \"type\": \"string\"}]},"; + + "{\"name\": \"non_pii_col\", \"type\": \"string\"}]}"; Schema expectedSchema = new Schema.Parser().parse(schemaStr); GenericRecord rec = new GenericData.Record(new Schema.Parser().parse(EXAMPLE_SCHEMA)); rec.put("_row_key", "key1"); @@ -318,7 +318,7 @@ public void testRemoveFields() { schemaStr = "{\"type\": \"record\",\"name\": \"testrec\",\"fields\": [ " + "{\"name\": \"timestamp\",\"type\": \"double\"},{\"name\": \"_row_key\", \"type\": \"string\"}," + "{\"name\": \"non_pii_col\", \"type\": \"string\"}," - + "{\"name\": \"pii_col\", \"type\": \"string\"}]},"; + + "{\"name\": \"pii_col\", \"type\": \"string\"}]}"; expectedSchema = new Schema.Parser().parse(schemaStr); rec1 = HoodieAvroUtils.removeFields(rec, Collections.singleton("")); assertEquals(expectedSchema, rec1.getSchema()); diff --git a/hudi-common/src/test/java/org/apache/hudi/common/util/TestClusteringUtils.java b/hudi-common/src/test/java/org/apache/hudi/common/util/TestClusteringUtils.java index 45d2b41129106..fb3adaf1c018a 100644 --- a/hudi-common/src/test/java/org/apache/hudi/common/util/TestClusteringUtils.java +++ b/hudi-common/src/test/java/org/apache/hudi/common/util/TestClusteringUtils.java @@ -37,6 +37,7 @@ import org.apache.hudi.common.testutils.HoodieCommonTestHarness; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.io.IOException; @@ -107,6 +108,7 @@ public void testClusteringPlanMultipleInstants() throws Exception { // replacecommit.inflight doesn't have clustering plan. // Verify that getClusteringPlan fetches content from corresponding requested file. + @Disabled("Will fail due to avro issue AVRO-3789. 
This is fixed in avro 1.11.3") @Test public void testClusteringPlanInflight() throws Exception { String partitionPath1 = "partition1"; diff --git a/hudi-integ-test/src/main/java/org/apache/hudi/integ/testsuite/dag/nodes/BaseValidateDatasetNode.java b/hudi-integ-test/src/main/java/org/apache/hudi/integ/testsuite/dag/nodes/BaseValidateDatasetNode.java index 8f86421c77243..892730c675b7e 100644 --- a/hudi-integ-test/src/main/java/org/apache/hudi/integ/testsuite/dag/nodes/BaseValidateDatasetNode.java +++ b/hudi-integ-test/src/main/java/org/apache/hudi/integ/testsuite/dag/nodes/BaseValidateDatasetNode.java @@ -20,6 +20,7 @@ package org.apache.hudi.integ.testsuite.dag.nodes; import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.SparkAdapterSupport$; import org.apache.hudi.common.model.HoodieCommitMetadata; import org.apache.hudi.common.table.HoodieTableMetaClient; import org.apache.hudi.common.table.timeline.HoodieTimeline; @@ -40,10 +41,7 @@ import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer$; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.types.StructType; import org.slf4j.Logger; @@ -51,11 +49,8 @@ import java.util.Arrays; import java.util.Comparator; import java.util.List; -import java.util.stream.Collectors; import scala.Tuple2; -import scala.collection.JavaConversions; -import scala.collection.JavaConverters; import static org.apache.hudi.utilities.deltastreamer.HoodieDeltaStreamer.CHECKPOINT_KEY; @@ -244,10 +239,6 @@ private Dataset getInputDf(ExecutionContext context, SparkSession session, } private ExpressionEncoder getEncoder(StructType schema) { - List attributes = JavaConversions.asJavaCollection(schema.toAttributes()).stream() - .map(Attribute::toAttribute).collect(Collectors.toList()); - return RowEncoder.apply(schema) - .resolveAndBind(JavaConverters.asScalaBufferConverter(attributes).asScala().toSeq(), - SimpleAnalyzer$.MODULE$); + return SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema); } } diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala index eaeff8bc7e9fd..e8cd59bb2dd20 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala @@ -66,7 +66,6 @@ import org.apache.spark.sql.{Row, SQLContext, SparkSession} import java.net.URI import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import scala.util.{Failure, Success, Try} trait HoodieFileSplit {} @@ -423,7 +422,8 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext, inMemoryFileIndex.listFiles(partitionFilters, dataFilters) } - val fsView = new HoodieTableFileSystemView(metaClient, timeline, partitionDirs.flatMap(_.files).toArray) + val fsView = new HoodieTableFileSystemView( + metaClient, timeline, sparkAdapter.getSparkPartitionedFileUtils.toFileStatuses(partitionDirs).toArray) fsView.getPartitionPaths.asScala.flatMap { partitionPath => val relativePath = getRelativePartitionPath(basePath, partitionPath) diff --git 
a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieCDCFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieCDCFileIndex.scala index 0fe3dcd2b5157..2bcea15df4b2c 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieCDCFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieCDCFileIndex.scala @@ -59,7 +59,7 @@ class HoodieCDCFileIndex (override val spark: SparkSession, }}.toList val partitionValues: InternalRow = new GenericInternalRow(doParsePartitionColumnValues( metaClient.getTableConfig.getPartitionFields.get(), partitionPath).asInstanceOf[Array[Any]]) - PartitionDirectory( + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( new HoodiePartitionCDCFileGroupMapping( partitionValues, fileGroups.map(kv => kv._1 -> kv._2.asScala.toList).toMap), fileGroupIds diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala index 2767b6982cd12..580a03dae709f 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieFileIndex.scala @@ -167,9 +167,11 @@ case class HoodieFileIndex(spark: SparkSession, || (f.getBaseFile.isPresent && f.getBaseFile.get().getBootstrapBaseFile.isPresent)). foldLeft(Map[String, FileSlice]()) { (m, f) => m + (f.getFileId -> f) } if (c.nonEmpty) { - PartitionDirectory(new HoodiePartitionFileSliceMapping(InternalRow.fromSeq(partitionOpt.get.values), c), baseFileStatusesAndLogFileOnly) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + new HoodiePartitionFileSliceMapping(InternalRow.fromSeq(partitionOpt.get.values), c), baseFileStatusesAndLogFileOnly) } else { - PartitionDirectory(InternalRow.fromSeq(partitionOpt.get.values), baseFileStatusesAndLogFileOnly) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + InternalRow.fromSeq(partitionOpt.get.values), baseFileStatusesAndLogFileOnly) } } else { @@ -184,7 +186,8 @@ case class HoodieFileIndex(spark: SparkSession, baseFileStatusOpt.foreach(f => files.append(f)) files }) - PartitionDirectory(InternalRow.fromSeq(partitionOpt.get.values), allCandidateFiles) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + InternalRow.fromSeq(partitionOpt.get.values), allCandidateFiles) } } diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieIncrementalFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieIncrementalFileIndex.scala index f68a0cac63b33..7e765b4d08f3a 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieIncrementalFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieIncrementalFileIndex.scala @@ -67,9 +67,11 @@ class HoodieIncrementalFileIndex(override val spark: SparkSession, || (f.getBaseFile.isPresent && f.getBaseFile.get().getBootstrapBaseFile.isPresent)). 
foldLeft(Map[String, FileSlice]()) { (m, f) => m + (f.getFileId -> f) } if (c.nonEmpty) { - PartitionDirectory(new HoodiePartitionFileSliceMapping(partitionValues, c), baseFileStatusesAndLogFileOnly) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + new HoodiePartitionFileSliceMapping(partitionValues, c), baseFileStatusesAndLogFileOnly) } else { - PartitionDirectory(partitionValues, baseFileStatusesAndLogFileOnly) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + partitionValues, baseFileStatusesAndLogFileOnly) } } else { val allCandidateFiles: Seq[FileStatus] = fileSlices.flatMap(fs => { @@ -83,7 +85,8 @@ class HoodieIncrementalFileIndex(override val spark: SparkSession, baseFileStatusOpt.foreach(f => files.append(f)) files }) - PartitionDirectory(partitionValues, allCandidateFiles) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + partitionValues, allCandidateFiles) } }.toSeq diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/execution/datasources/HoodieInMemoryFileIndex.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/execution/datasources/HoodieInMemoryFileIndex.scala index ad1e87f8ce04a..e69364d676601 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/execution/datasources/HoodieInMemoryFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/execution/datasources/HoodieInMemoryFileIndex.scala @@ -49,7 +49,8 @@ class HoodieInMemoryFileIndex(sparkSession: SparkSession, */ override def listFiles(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): Seq[PartitionDirectory] = { val selectedPartitions = if (partitionSpec().partitionColumns.isEmpty) { - PartitionDirectory(InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory( + InternalRow.empty, allFiles().filter(f => isDataPath(f.getPath))) :: Nil } else { prunePartitions(partitionFilters, partitionSpec()).map { case PartitionPath(values, path) => @@ -62,7 +63,7 @@ class HoodieInMemoryFileIndex(sparkSession: SparkSession, // Directory does not exist, or has no children files Nil } - PartitionDirectory(values, files) + sparkAdapter.getSparkPartitionedFileUtils.newPartitionDirectory(values, files) } } logTrace("Selected files after partition pruning:\n\t" + selectedPartitions.mkString("\n\t")) diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java b/hudi-spark-datasource/hudi-spark-common/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java similarity index 94% rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java rename to hudi-spark-datasource/hudi-spark-common/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java index 1e07378f9afd0..494e17ad76899 100644 --- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java +++ b/hudi-spark-datasource/hudi-spark-common/src/test/java/org/apache/hudi/testutils/SparkDatasetTestUtils.java @@ -18,13 +18,14 @@ package org.apache.hudi.testutils; +import org.apache.hudi.SparkAdapterSupport$; +import org.apache.hudi.common.config.HoodieStorageConfig; import org.apache.hudi.common.model.HoodieKey; import org.apache.hudi.common.model.HoodieRecord; import org.apache.hudi.common.table.view.FileSystemViewStorageConfig; import 
org.apache.hudi.common.testutils.HoodieTestDataGenerator; import org.apache.hudi.config.HoodieCompactionConfig; import org.apache.hudi.config.HoodieIndexConfig; -import org.apache.hudi.common.config.HoodieStorageConfig; import org.apache.hudi.config.HoodieWriteConfig; import org.apache.hudi.index.HoodieIndex; @@ -33,10 +34,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer$; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.DataTypes; @@ -49,15 +47,14 @@ import java.util.ArrayList; import java.util.List; import java.util.UUID; -import java.util.stream.Collectors; - -import scala.collection.JavaConversions; -import scala.collection.JavaConverters; import static org.apache.hudi.common.testutils.FileSystemTestUtils.RANDOM; /** * Dataset test utils. + * Note: This util class can only be used within `hudi-spark` modules because it + * relies on SparkAdapterSupport to get the encoder for different versions of Spark. If used elsewhere, this + * class won't be initialized properly and could cause ClassNotFoundException or NoClassDefFoundError. */ public class SparkDatasetTestUtils { @@ -96,11 +93,7 @@ public class SparkDatasetTestUtils { * @return the encoder thus generated. */ private static ExpressionEncoder getEncoder(StructType schema) { - List attributes = JavaConversions.asJavaCollection(schema.toAttributes()).stream() - .map(Attribute::toAttribute).collect(Collectors.toList()); - return RowEncoder.apply(schema) - .resolveAndBind(JavaConverters.asScalaBufferConverter(attributes).asScala().toSeq(), - SimpleAnalyzer$.MODULE$); + return SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema); } /** diff --git a/hudi-spark-datasource/hudi-spark/pom.xml b/hudi-spark-datasource/hudi-spark/pom.xml index fb5373541ed98..9e6ca76337b65 100644 --- a/hudi-spark-datasource/hudi-spark/pom.xml +++ b/hudi-spark-datasource/hudi-spark/pom.xml @@ -245,6 +245,12 @@ org.apache.parquet parquet-avro + + org.apache.parquet + parquet-hadoop-bundle + ${parquet.version} + provided + @@ -335,6 +341,10 @@ org.pentaho * + + org.apache.parquet + * + @@ -350,6 +360,10 @@ javax.servlet.jsp * + + org.apache.parquet + * + @@ -365,6 +379,10 @@ javax.servlet.jsp * + + org.apache.parquet + * + @@ -376,6 +394,10 @@ org.eclipse.jetty.orbit javax.servlet + + org.apache.parquet + * + @@ -420,6 +442,14 @@ test-jar test + + org.apache.hudi + hudi-spark-common_${scala.binary.version} + ${project.version} + tests + test-jar + test + org.apache.hudi hudi-common diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index 051c38d956775..9d4014021e594 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -77,7 +77,16 @@ object HoodieAnalysis extends SparkAdapterSupport { } } else { rules += adaptIngestionTargetLogicalRelations val 
dataSourceV2ToV1FallbackClass = "org.apache.spark.sql.hudi.analysis.HoodieDataSourceV2ToV1Fallback" + val dataSourceV2ToV1FallbackClass = if (HoodieSparkUtils.isSpark3_5) + "org.apache.spark.sql.hudi.analysis.HoodieSpark35DataSourceV2ToV1Fallback" + else if (HoodieSparkUtils.isSpark3_4) + "org.apache.spark.sql.hudi.analysis.HoodieSpark34DataSourceV2ToV1Fallback" + else if (HoodieSparkUtils.isSpark3_3) + "org.apache.spark.sql.hudi.analysis.HoodieSpark33DataSourceV2ToV1Fallback" + else { + // Spark 3.2.x + "org.apache.spark.sql.hudi.analysis.HoodieSpark32DataSourceV2ToV1Fallback" + } val dataSourceV2ToV1Fallback: RuleBuilder = session => instantiateKlass(dataSourceV2ToV1FallbackClass, session) @@ -95,7 +104,9 @@ object HoodieAnalysis extends SparkAdapterSupport { if (HoodieSparkUtils.isSpark3) { val resolveAlterTableCommandsClass = - if (HoodieSparkUtils.gteqSpark3_4) { + if (HoodieSparkUtils.gteqSpark3_5) { + "org.apache.spark.sql.hudi.Spark35ResolveHudiAlterTableCommand" + } else if (HoodieSparkUtils.gteqSpark3_4) { "org.apache.spark.sql.hudi.Spark34ResolveHudiAlterTableCommand" } else if (HoodieSparkUtils.gteqSpark3_3) { "org.apache.spark.sql.hudi.Spark33ResolveHudiAlterTableCommand" @@ -149,7 +160,9 @@ object HoodieAnalysis extends SparkAdapterSupport { if (HoodieSparkUtils.gteqSpark3_0) { val nestedSchemaPruningClass = - if (HoodieSparkUtils.gteqSpark3_4) { + if (HoodieSparkUtils.gteqSpark3_5) { + "org.apache.spark.sql.execution.datasources.Spark35NestedSchemaPruning" + } else if (HoodieSparkUtils.gteqSpark3_4) { "org.apache.spark.sql.execution.datasources.Spark34NestedSchemaPruning" } else if (HoodieSparkUtils.gteqSpark3_3) { "org.apache.spark.sql.execution.datasources.Spark33NestedSchemaPruning" diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala index f63f4115e9195..f185096961936 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CallProcedureHoodieCommand.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.hudi.command +import org.apache.hudi.SparkAdapterSupport import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.hudi.command.procedures.{Procedure, ProcedureArgs} import org.apache.spark.sql.{Row, SparkSession} -import scala.collection.Seq - case class CallProcedureHoodieCommand( procedure: Procedure, args: ProcedureArgs) extends HoodieLeafRunnableCommand { - override def output: Seq[Attribute] = procedure.outputType.toAttributes + override def output: Seq[Attribute] = + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes(procedure.outputType) override def run(sparkSession: SparkSession): Seq[Row] = { procedure.call(args) diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodiePathCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodiePathCommand.scala index 57aff092b7429..5bb62524a2bc4 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodiePathCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodiePathCommand.scala @@ -17,6 +17,7 @@ package 
org.apache.spark.sql.hudi.command +import org.apache.hudi.SparkAdapterSupport import org.apache.hudi.common.model.HoodieTableType import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.spark.sql.catalyst.expressions.Attribute @@ -48,5 +49,7 @@ case class CompactionHoodiePathCommand(path: String, RunCompactionProcedure.builder.get().build.call(procedureArgs) } - override val output: Seq[Attribute] = RunCompactionProcedure.builder.get().build.outputType.toAttributes + override val output: Seq[Attribute] = + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes( + RunCompactionProcedure.builder.get().build.outputType) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodieTableCommand.scala index adaaeae9e55c9..426d6f27720b4 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionHoodieTableCommand.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hudi.command +import org.apache.hudi.SparkAdapterSupport import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.CompactionOperation.CompactionOperation @@ -35,5 +36,7 @@ case class CompactionHoodieTableCommand(table: CatalogTable, CompactionHoodiePathCommand(basePath, operation, instantTimestamp).run(sparkSession) } - override val output: Seq[Attribute] = RunCompactionProcedure.builder.get().build.outputType.toAttributes + override val output: Seq[Attribute] = + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes( + RunCompactionProcedure.builder.get().build.outputType) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodiePathCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodiePathCommand.scala index 95a4ecf7800e6..a61bea7aa8481 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodiePathCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodiePathCommand.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hudi.command +import org.apache.hudi.SparkAdapterSupport import org.apache.hudi.common.model.HoodieTableType import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.spark.sql.catalyst.expressions.Attribute @@ -40,5 +41,7 @@ case class CompactionShowHoodiePathCommand(path: String, limit: Int) ShowCompactionProcedure.builder.get().build.call(procedureArgs) } - override val output: Seq[Attribute] = ShowCompactionProcedure.builder.get().build.outputType.toAttributes + override val output: Seq[Attribute] = + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes( + ShowCompactionProcedure.builder.get().build.outputType) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodieTableCommand.scala index afd15d5153db6..070e93912aba0 100644 --- 
a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/CompactionShowHoodieTableCommand.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.hudi.command +import org.apache.hudi.SparkAdapterSupport import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.getTableLocation @@ -32,5 +33,7 @@ case class CompactionShowHoodieTableCommand(table: CatalogTable, limit: Int) CompactionShowHoodiePathCommand(basePath, limit).run(sparkSession) } - override val output: Seq[Attribute] = ShowCompactionProcedure.builder.get().build.outputType.toAttributes + override val output: Seq[Attribute] = + SparkAdapterSupport.sparkAdapter.getSchemaUtils.toAttributes( + ShowCompactionProcedure.builder.get().build.outputType) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala index b8d5be7638fb4..008118bfcce04 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala @@ -180,11 +180,15 @@ object InsertIntoHoodieTableCommand extends Logging with ProvidesHoodieConfig wi conf: SQLConf): LogicalPlan = { val planUtils = sparkAdapter.getCatalystPlanUtils try { - planUtils.resolveOutputColumns(catalogTable.catalogTableName, expectedSchema.toAttributes, query, byName = true, conf) + planUtils.resolveOutputColumns( + catalogTable.catalogTableName, sparkAdapter.getSchemaUtils.toAttributes(expectedSchema), query, byName = true, conf) } catch { // NOTE: In case matching by name didn't match the query output, we will attempt positional matching - case ae: AnalysisException if ae.getMessage().startsWith("Cannot write incompatible data to table") => - planUtils.resolveOutputColumns(catalogTable.catalogTableName, expectedSchema.toAttributes, query, byName = false, conf) + // SPARK-42309 Error message changed in Spark 3.5.0 so we need to match two strings here + case ae: AnalysisException if (ae.getMessage().startsWith("[INCOMPATIBLE_DATA_FOR_TABLE.CANNOT_FIND_DATA] Cannot write incompatible data for the table") + || ae.getMessage().startsWith("Cannot write incompatible data to table")) => + planUtils.resolveOutputColumns( + catalogTable.catalogTableName, sparkAdapter.getSchemaUtils.toAttributes(expectedSchema), query, byName = false, conf) } } diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/execution/bulkinsert/TestBulkInsertInternalPartitionerForRows.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/execution/bulkinsert/TestBulkInsertInternalPartitionerForRows.java similarity index 100% rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/execution/bulkinsert/TestBulkInsertInternalPartitionerForRows.java rename to hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/execution/bulkinsert/TestBulkInsertInternalPartitionerForRows.java diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java 
b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java index 8166820cb8795..52c4568697832 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java @@ -17,10 +17,10 @@ package org.apache.hudi.functional; -import org.apache.avro.Schema; import org.apache.hudi.AvroConversionUtils; import org.apache.hudi.DataSourceWriteOptions; import org.apache.hudi.HoodieDatasetBulkInsertHelper; +import org.apache.hudi.SparkAdapterSupport$; import org.apache.hudi.common.config.TypedProperties; import org.apache.hudi.common.model.HoodieRecord; import org.apache.hudi.common.util.FileIOUtils; @@ -33,34 +33,31 @@ import org.apache.hudi.metadata.HoodieTableMetadata; import org.apache.hudi.testutils.DataSourceTestUtils; import org.apache.hudi.testutils.HoodieSparkClientTestBase; + +import org.apache.avro.Schema; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.api.java.function.ReduceFunction; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer$; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.types.StructType; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import scala.Tuple2; -import scala.collection.JavaConversions; -import scala.collection.JavaConverters; import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; +import scala.Tuple2; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -349,10 +346,6 @@ public void testNoPropsSet() { } private ExpressionEncoder getEncoder(StructType schema) { - List attributes = JavaConversions.asJavaCollection(schema.toAttributes()).stream() - .map(Attribute::toAttribute).collect(Collectors.toList()); - return RowEncoder.apply(schema) - .resolveAndBind(JavaConverters.asScalaBufferConverter(attributes).asScala().toSeq(), - SimpleAnalyzer$.MODULE$); + return SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema); } } diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/io/storage/row/TestHoodieInternalRowParquetWriter.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/storage/row/TestHoodieInternalRowParquetWriter.java similarity index 100% rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/io/storage/row/TestHoodieInternalRowParquetWriter.java rename to hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/storage/row/TestHoodieInternalRowParquetWriter.java diff --git a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java 
b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java similarity index 94% rename from hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java rename to hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java index a88f4dcf9e89c..86aa6cff7a3d7 100644 --- a/hudi-client/hudi-spark-client/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/io/storage/row/TestHoodieRowCreateHandle.java @@ -45,6 +45,8 @@ import java.util.Random; import java.util.UUID; +import static org.apache.hudi.common.testutils.HoodieTestUtils.getJavaVersion; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -166,7 +168,14 @@ public void testGlobalFailure() throws Exception { fileNames.add(handle.getFileName()); // verify write status assertNotNull(writeStatus.getGlobalError()); - assertTrue(writeStatus.getGlobalError().getMessage().contains("java.lang.String cannot be cast to org.apache.spark.unsafe.types.UTF8String")); + + String expectedError = getJavaVersion() == 11 || getJavaVersion() == 17 + ? "class java.lang.String cannot be cast to class org.apache.spark.unsafe.types.UTF8String" + : "java.lang.String cannot be cast to org.apache.spark.unsafe.types.UTF8String"; + + assertTrue(writeStatus.getGlobalError().getMessage().contains(expectedError), + () -> "Expected error to contain: " + expectedError + ", the actual error message: " + writeStatus.getGlobalError().getMessage()); + assertEquals(writeStatus.getFileId(), fileId); assertEquals(writeStatus.getPartitionPath(), partitionPath); diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java index e1f8f9f6105ec..d704e833ba082 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/KeyGeneratorTestUtilities.java @@ -18,27 +18,23 @@ package org.apache.hudi.testutils; +import org.apache.hudi.AvroConversionUtils; +import org.apache.hudi.SparkAdapterSupport$; + import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; -import org.apache.hudi.AvroConversionUtils; import org.apache.spark.package$; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer$; import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; -import org.apache.spark.sql.catalyst.encoders.RowEncoder; -import org.apache.spark.sql.catalyst.expressions.Attribute; import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import org.apache.spark.sql.types.StructType; -import scala.Function1; -import scala.collection.JavaConversions; -import scala.collection.JavaConverters; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.List; -import java.util.stream.Collectors; + +import scala.Function1; public class KeyGeneratorTestUtilities { @@ -101,11 +97,7 @@ public static 
InternalRow getInternalRow(Row row) { } private static ExpressionEncoder getEncoder(StructType schema) { - List attributes = JavaConversions.asJavaCollection(schema.toAttributes()).stream() - .map(Attribute::toAttribute).collect(Collectors.toList()); - return RowEncoder.apply(schema) - .resolveAndBind(JavaConverters.asScalaBufferConverter(attributes).asScala().toSeq(), - SimpleAnalyzer$.MODULE$); + return SparkAdapterSupport$.MODULE$.sparkAdapter().getCatalystExpressionUtils().getEncoder(schema); } public static InternalRow getInternalRow(Row row, ExpressionEncoder encoder) throws ClassNotFoundException, InvocationTargetException, IllegalAccessException, NoSuchMethodException { diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala index d42e28fb98104..70d74f0a0cea5 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestAvroConversionUtils.scala @@ -383,7 +383,7 @@ class TestAvroConversionUtils extends FunSuite with Matchers { } } ] - }} + } """ val expectedAvroSchema = new Schema.Parser().parse(expectedSchemaStr) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/common/table/read/TestHoodieFileGroupReaderOnSpark.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/common/table/read/TestHoodieFileGroupReaderOnSpark.scala index 087ae50ff02cc..a5f25d8bd3406 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/common/table/read/TestHoodieFileGroupReaderOnSpark.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/common/table/read/TestHoodieFileGroupReaderOnSpark.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.{Dataset, HoodieInternalRowUtils, HoodieUnsafeUtils, import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.{HoodieSparkKryoRegistrar, SparkConf} import org.junit.jupiter.api.Assertions.assertEquals -import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.{AfterEach, BeforeEach} import java.util import scala.collection.JavaConversions._ @@ -68,6 +68,13 @@ class TestHoodieFileGroupReaderOnSpark extends TestHoodieFileGroupReaderBase[Int spark = SparkSession.builder.config(sparkConf).getOrCreate } + @AfterEach + def teardown(): Unit = { + if (spark != null) { + spark.stop() + } + } + override def getHadoopConf: Configuration = { FSUtils.buildInlineConf(new Configuration) } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala index 80fad42b247e0..e553c9b66cff3 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala @@ -537,6 +537,10 @@ class TestInsertTable extends HoodieSparkSqlTestBase { test("Test insert for uppercase table name") { withRecordType()(withTempDir{ tmp => val tableName = s"H_$generateTableName" + if (HoodieSparkUtils.gteqSpark3_5) { + // [SPARK-44284] Spark 3.5+ requires the conf below so that table names are treated as case sensitive + spark.sql(s"set spark.sql.caseSensitive=true") + } spark.sql( s""" @@ -557,7 +561,7 @@ class TestInsertTable extends HoodieSparkSqlTestBase { .setBasePath(tmp.getCanonicalPath) 
.setConf(spark.sessionState.newHadoopConf()) .build() - assertResult(metaClient.getTableConfig.getTableName)(tableName) + assertResult(tableName)(metaClient.getTableConfig.getTableName) }) } @@ -575,7 +579,13 @@ class TestInsertTable extends HoodieSparkSqlTestBase { | tblproperties (primaryKey = 'id') | partitioned by (dt) """.stripMargin) - val tooManyDataColumnsErrorMsg = if (HoodieSparkUtils.gteqSpark3_4) { + val tooManyDataColumnsErrorMsg = if (HoodieSparkUtils.gteqSpark3_5) { + s""" + |[INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS] Cannot write to `spark_catalog`.`default`.`$tableName`, the reason is too many data columns: + |Table columns: `id`, `name`, `price`. + |Data columns: `1`, `a1`, `10`, `2021-06-20`. + |""".stripMargin + } else if (HoodieSparkUtils.gteqSpark3_4) { """ |too many data columns: |Table columns: 'id', 'name', 'price'. @@ -591,7 +601,13 @@ class TestInsertTable extends HoodieSparkSqlTestBase { checkExceptionContain(s"insert into $tableName partition(dt = '2021-06-20') select 1, 'a1', 10, '2021-06-20'")( tooManyDataColumnsErrorMsg) - val notEnoughDataColumnsErrorMsg = if (HoodieSparkUtils.gteqSpark3_4) { + val notEnoughDataColumnsErrorMsg = if (HoodieSparkUtils.gteqSpark3_5) { + s""" + |[INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS] Cannot write to `spark_catalog`.`default`.`$tableName`, the reason is not enough data columns: + |Table columns: `id`, `name`, `price`, `dt`. + |Data columns: `1`, `a1`, `10`. + |""".stripMargin + } else if (HoodieSparkUtils.gteqSpark3_4) { """ |not enough data columns: |Table columns: 'id', 'name', 'price', 'dt'. diff --git a/hudi-spark-datasource/hudi-spark2/pom.xml b/hudi-spark-datasource/hudi-spark2/pom.xml index d996057293e35..d52cf47755529 100644 --- a/hudi-spark-datasource/hudi-spark2/pom.xml +++ b/hudi-spark-datasource/hudi-spark2/pom.xml @@ -197,6 +197,14 @@ true + + org.apache.spark + spark-core_${scala.binary.version} + ${spark2.version} + provided + true + + org.apache.hudi diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala index ea5841ecdf43a..337773db162a9 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala @@ -18,11 +18,16 @@ package org.apache.spark.sql import HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Like, Log, Log10, Log1p, Log2, Lower, Multiply, Or, ParseToDate, ParseToTimestamp, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + RowEncoder.apply(schema).resolveAndBind() + } + // NOTE: This method has been borrowed from Spark 3.1 override def extractPredicatesWithinOutputSet(condition: Expression, outputSet: AttributeSet): Option[Expression] = condition 
match { diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2SchemaUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2SchemaUtils.scala index e2c1dc4a24449..beee0d293dfd4 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark2SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, colType, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala index ec275a1d3fdc2..00e4d0c1ca911 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.adapter import org.apache.avro.Schema +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.fs.Path import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.hudi.common.table.HoodieTableMetaClient @@ -26,8 +27,8 @@ import org.apache.hudi.{AvroConversionUtils, DefaultSource, Spark2HoodieFileScan import org.apache.spark.sql._ import org.apache.spark.sql.avro._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate} +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, InterpretedPredicate} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable, Join, LogicalPlan} @@ -91,7 +92,7 @@ class Spark2Adapter extends SparkAdapter { override def getAvroSchemaConverters: HoodieAvroSchemaConverters = HoodieSparkAvroSchemaConverters override def createSparkRowSerDe(schema: StructType): SparkRowSerDe = { - val encoder = RowEncoder(schema).resolveAndBind() + val encoder = getCatalystExpressionUtils.getEncoder(schema) new Spark2RowSerDe(encoder) } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark2PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark2PartitionedFileUtils.scala index 66c4722f6619a..99b0a58bb25a8 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark2PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark2PartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import 
org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] for Spark 2.4. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 2.4. */ object HoodieSpark2PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -40,4 +40,12 @@ object HoodieSpark2PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, filePath.toUri.toString, start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark-common/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark2/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java similarity index 100% rename from hudi-spark-datasource/hudi-spark-common/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java rename to hudi-spark-datasource/hudi-spark2/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/java/org/apache/hudi/spark3/internal/ReflectUtil.java b/hudi-spark-datasource/hudi-spark3-common/src/main/java/org/apache/hudi/spark3/internal/ReflectUtil.java index d7a9a1f12241d..ad83720b0213b 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/java/org/apache/hudi/spark3/internal/ReflectUtil.java +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/java/org/apache/hudi/spark3/internal/ReflectUtil.java @@ -33,9 +33,13 @@ public class ReflectUtil { public static InsertIntoStatement createInsertInto(LogicalPlan table, Map<String, Option<String>> partition, Seq<String> userSpecifiedCols, - LogicalPlan query, boolean overwrite, boolean ifPartitionNotExists) { + LogicalPlan query, boolean overwrite, boolean ifPartitionNotExists, boolean byName) { try { - if (HoodieSparkUtils.isSpark3_0()) { + if (HoodieSparkUtils.gteqSpark3_5()) { + Constructor<InsertIntoStatement> constructor = InsertIntoStatement.class.getConstructor( + LogicalPlan.class, Map.class, Seq.class, LogicalPlan.class, boolean.class, boolean.class, boolean.class); + return constructor.newInstance(table, partition, userSpecifiedCols, query, overwrite, ifPartitionNotExists, byName); + } else if (HoodieSparkUtils.isSpark3_0()) { Constructor<InsertIntoStatement> constructor = InsertIntoStatement.class.getConstructor( LogicalPlan.class, Map.class, LogicalPlan.class, boolean.class, boolean.class); return constructor.newInstance(table, partition, query, overwrite, ifPartitionNotExists); diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala index b2a9a529511ec..01e435b4f8d26 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala @@ -26,15 +26,14 @@ import org.apache.hudi.spark3.internal.ReflectUtil import 
org.apache.hudi.{AvroConversionUtils, DefaultSource, HoodieSparkUtils, Spark3RowSerDe} import org.apache.spark.internal.Logging import org.apache.spark.sql.avro.{HoodieAvroSchemaConverters, HoodieSparkAvroSchemaConverters} -import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedPredicate, Predicate} import org.apache.spark.sql.catalyst.util.DateFormatter import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hudi.SparkAdapter import org.apache.spark.sql.sources.{BaseRelation, Filter} -import org.apache.spark.sql.{HoodieSpark3CatalogUtils, SQLContext, SparkSession} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ColumnVector, ColumnarBatch} +import org.apache.spark.sql.{HoodieSpark3CatalogUtils, SQLContext, SparkSession} import org.apache.spark.storage.StorageLevel import java.time.ZoneId @@ -57,8 +56,7 @@ abstract class BaseSpark3Adapter extends SparkAdapter with Logging { def getCatalogUtils: HoodieSpark3CatalogUtils override def createSparkRowSerDe(schema: StructType): SparkRowSerDe = { - val encoder = RowEncoder(schema).resolveAndBind() - new Spark3RowSerDe(encoder) + new Spark3RowSerDe(getCatalystExpressionUtils.getEncoder(schema)) } override def getAvroSchemaConverters: HoodieAvroSchemaConverters = HoodieSparkAvroSchemaConverters diff --git a/hudi-spark-datasource/hudi-spark3.0.x/pom.xml b/hudi-spark-datasource/hudi-spark3.0.x/pom.xml index 5842e9a8c5de8..e6d39bfd0821b 100644 --- a/hudi-spark-datasource/hudi-spark3.0.x/pom.xml +++ b/hudi-spark-datasource/hudi-spark3.0.x/pom.xml @@ -157,6 +157,14 @@ true + + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark30.version}</version> + <scope>provided</scope> + <optional>true</optional> + </dependency> + com.fasterxml.jackson.core jackson-databind @@ -263,6 +271,13 @@ + + <dependency> + <groupId>org.apache.parquet</groupId> + <artifactId>parquet-avro</artifactId> + <scope>test</scope> + </dependency> + diff --git a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30CatalystExpressionUtils.scala index ef3e8fdb6d16b..c4708be813b4a 100644 --- a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30CatalystExpressionUtils.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{AnsiCast, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark30CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + RowEncoder.apply(schema).resolveAndBind() + } + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = expr match { case Cast(child, dataType, timeZoneId) => Some((child, dataType, timeZoneId)) diff --git 
a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30SchemaUtils.scala index 10775e11a4bbe..f66fd837c7e84 100644 --- a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/HoodieSpark30SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark30SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, colType, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark30PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark30PartitionedFileUtils.scala index 0abc17db05b40..5282e110c1fc3 100644 --- a/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark30PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.0.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark30PartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] for Spark 3.0. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.0. */ object HoodieSpark30PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -40,4 +40,12 @@ object HoodieSpark30PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, filePath.toUri.toString, start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. + */ +public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness { + + protected static final Random RANDOM = new Random(); + + @BeforeEach + public void setUp() throws Exception { + initSparkContexts(); + initPath(); + initFileSystem(); + initTestDataGenerator(); + initMetaClient(); + initTimelineService(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanupResources(); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) { + return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue()); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) { + Properties properties = new Properties(); + if (!populateMetaFields) { + properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME); + properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME); + properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false"); + } + properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue); + return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build(); + } + + protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size, + Option<List<String>> fileAbsPaths, Option<List<String>> fileNames) { + assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false); + } + + protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size, boolean areRecordsSorted, + Option<List<String>> fileAbsPaths, Option<List<String>> fileNames, boolean isHiveStylePartitioning) { + if (areRecordsSorted) { + 
assertEquals(batches, writeStatuses.size()); + } else { + assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size()); + } + + Map<String, Long> sizeMap = new HashMap<>(); + if (!areRecordsSorted) { + // N records are written per batch. Every 4th batch goes into the same writeStatus, so populate the + // expected size per write status + for (int i = 0; i < batches; i++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3]; + if (!sizeMap.containsKey(partitionPath)) { + sizeMap.put(partitionPath, 0L); + } + sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size); + } + } + + int counter = 0; + for (WriteStatus writeStatus : writeStatuses) { + // verify write status + String actualPartitionPathFormat = isHiveStylePartitioning ? SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s"; + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath()); + if (areRecordsSorted) { + assertEquals(writeStatus.getTotalRecords(), size); + } else { + assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3])); + } + assertNull(writeStatus.getGlobalError()); + assertEquals(writeStatus.getTotalErrorRecords(), 0); + assertFalse(writeStatus.hasErrors()); + assertNotNull(writeStatus.getFileId()); + String fileId = writeStatus.getFileId(); + if (fileAbsPaths.isPresent()) { + fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath()); + } + if (fileNames.isPresent()) { + fileNames.get().add(writeStatus.getStat().getPath() + .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1)); + } + HoodieWriteStat writeStat = writeStatus.getStat(); + if (areRecordsSorted) { + assertEquals(size, writeStat.getNumInserts()); + assertEquals(size, writeStat.getNumWrites()); + } else { + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts()); + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites()); + } + assertEquals(fileId, writeStat.getFileId()); + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath()); + assertEquals(0, writeStat.getNumDeletes()); + assertEquals(0, writeStat.getNumUpdateWrites()); + assertEquals(0, writeStat.getTotalWriteErrors()); + } + } + + protected void assertOutput(Dataset<Row> expectedRows, Dataset<Row> actualRows, String instantTime, Option<List<String>> fileNames, + boolean populateMetaColumns) { + if (populateMetaColumns) { + // verify 3 meta fields that are filled in within create handle + actualRows.collectAsList().forEach(entry -> { + assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime); + assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())); + if (fileNames.isPresent()) { + assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()))); + } + assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal())); + }); + + // after trimming the 3 meta fields, rest of the fields should match + Dataset<Row> trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + Dataset<Row> 
trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + assertEquals(0, trimmedActual.except(trimmedExpected).count()); + } else { // operation = BULK_INSERT_APPEND_ONLY + // all meta columns are untouched + assertEquals(0, expectedRows.except(actualRows).count()); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java b/hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java similarity index 100% rename from hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java rename to hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java diff --git a/hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java b/hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java similarity index 100% rename from hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java rename to hudi-spark-datasource/hudi-spark3.0.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java diff --git a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml index 6e1bb8be2e81e..f1d08470b637d 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml +++ b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml @@ -157,6 +157,14 @@ true + + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark31.version}</version> + <scope>provided</scope> + <optional>true</optional> + </dependency> + com.fasterxml.jackson.core jackson-databind @@ -263,6 +271,13 @@ + + <dependency> + <groupId>org.apache.parquet</groupId> + <artifactId>parquet-avro</artifactId> + <scope>test</scope> + </dependency> + diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala index 33e338d3afe8a..3d32b206fd147 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala @@ -19,12 +19,16 @@ package org.apache.spark.sql import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.types.DataType - +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark31CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils with PredicateHelper { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + 
RowEncoder.apply(schema).resolveAndBind() + } + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = DataSourceStrategy.normalizeExprs(exprs, attributes) diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31SchemaUtils.scala index c4753067f51e1..49388f5579135 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark31SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, colType, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark31PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark31PartitionedFileUtils.scala index 5a359234631d8..3be432691f8fe 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark31PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark31PartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] for Spark 3.1. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.1. */ object HoodieSpark31PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -40,4 +40,12 @@ object HoodieSpark31PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, filePath.toUri.toString, start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. 
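+ * This copy lives in the module's own test sources now that the shared base class has moved out of hudi-spark-common.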
+ */ +public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness { + + protected static final Random RANDOM = new Random(); + + @BeforeEach + public void setUp() throws Exception { + initSparkContexts(); + initPath(); + initFileSystem(); + initTestDataGenerator(); + initMetaClient(); + initTimelineService(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanupResources(); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) { + return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue()); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) { + Properties properties = new Properties(); + if (!populateMetaFields) { + properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME); + properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME); + properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false"); + } + properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue); + return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build(); + } + + protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size, + Option<List<String>> fileAbsPaths, Option<List<String>> fileNames) { + assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false); + } + + protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size, boolean areRecordsSorted, + Option<List<String>> fileAbsPaths, Option<List<String>> fileNames, boolean isHiveStylePartitioning) { + if (areRecordsSorted) { + assertEquals(batches, writeStatuses.size()); + } else { + assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size()); + } + + Map<String, Long> sizeMap = new HashMap<>(); + if (!areRecordsSorted) { + // N records are written per batch. Every 4th batch goes into the same writeStatus, so populate the + // expected size per write status + for (int i = 0; i < batches; i++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3]; + if (!sizeMap.containsKey(partitionPath)) { + sizeMap.put(partitionPath, 0L); + } + sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size); + } + } + + int counter = 0; + for (WriteStatus writeStatus : writeStatuses) { + // verify write status + String actualPartitionPathFormat = isHiveStylePartitioning ? 
SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s"; + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath()); + if (areRecordsSorted) { + assertEquals(writeStatus.getTotalRecords(), size); + } else { + assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3])); + } + assertNull(writeStatus.getGlobalError()); + assertEquals(writeStatus.getTotalErrorRecords(), 0); + assertFalse(writeStatus.hasErrors()); + assertNotNull(writeStatus.getFileId()); + String fileId = writeStatus.getFileId(); + if (fileAbsPaths.isPresent()) { + fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath()); + } + if (fileNames.isPresent()) { + fileNames.get().add(writeStatus.getStat().getPath() + .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1)); + } + HoodieWriteStat writeStat = writeStatus.getStat(); + if (areRecordsSorted) { + assertEquals(size, writeStat.getNumInserts()); + assertEquals(size, writeStat.getNumWrites()); + } else { + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts()); + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites()); + } + assertEquals(fileId, writeStat.getFileId()); + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath()); + assertEquals(0, writeStat.getNumDeletes()); + assertEquals(0, writeStat.getNumUpdateWrites()); + assertEquals(0, writeStat.getTotalWriteErrors()); + } + } + + protected void assertOutput(Dataset<Row> expectedRows, Dataset<Row> actualRows, String instantTime, Option<List<String>> fileNames, + boolean populateMetaColumns) { + if (populateMetaColumns) { + // verify 3 meta fields that are filled in within create handle + actualRows.collectAsList().forEach(entry -> { + assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime); + assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())); + if (fileNames.isPresent()) { + assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()))); + } + assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal())); + }); + + // after trimming the 3 meta fields, rest of the fields should match + Dataset<Row> trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + Dataset<Row> trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + assertEquals(0, trimmedActual.except(trimmedExpected).count()); + } else { // operation = BULK_INSERT_APPEND_ONLY + // all meta columns are untouched + assertEquals(0, expectedRows.except(actualRows).count()); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java new file mode 100644 index 0000000000000..206d4931b15e1 --- /dev/null +++ 
b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi.spark3.internal; + +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase; +import org.apache.hudi.table.HoodieSparkTable; +import org.apache.hudi.table.HoodieTable; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getInternalRowWithError; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Unit tests {@link HoodieBulkInsertDataInternalWriter}. + */ +public class TestHoodieBulkInsertDataInternalWriter extends + HoodieBulkInsertInternalWriterTestBase { + + private static Stream configParams() { + Object[][] data = new Object[][] { + {true, true}, + {true, false}, + {false, true}, + {false, false} + }; + return Stream.of(data).map(Arguments::of); + } + + private static Stream bulkInsertTypeParams() { + Object[][] data = new Object[][] { + {true}, + {false} + }; + return Stream.of(data).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("configParams") + public void testDataInternalWriter(boolean sorted, boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + // execute N rounds + for (int i = 0; i < 2; i++) { + String instantTime = "00" + i; + // init writer + HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000), + RANDOM.nextLong(), STRUCT_TYPE, populateMetaFields, sorted); + + int size = 10 + RANDOM.nextInt(1000); + // write N rows to partition1, N rows to partition2 and N rows to partition3 ... 
Each batch should create a new RowCreateHandle and a new file + int batches = 3; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + Option> fileAbsPaths = Option.of(new ArrayList<>()); + Option> fileNames = Option.of(new ArrayList<>()); + + // verify write statuses + assertWriteStatuses(commitMetadata.getWriteStatuses(), batches, size, sorted, fileAbsPaths, fileNames, false); + + // verify rows + Dataset result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0])); + assertOutput(totalInputRows, result, instantTime, fileNames, populateMetaFields); + } + } + + + /** + * Issue some corrupted or wrong schematized InternalRow after few valid InternalRows so that global error is thrown. write batch 1 of valid records write batch2 of invalid records which is expected + * to throw Global Error. Verify global error is set appropriately and only first batch of records are written to disk. + */ + @Test + public void testGlobalFailure() throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(true); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[0]; + + String instantTime = "001"; + HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000), + RANDOM.nextLong(), STRUCT_TYPE, true, false); + + int size = 10 + RANDOM.nextInt(100); + int totalFailures = 5; + // Generate first batch of valid rows + Dataset inputRows = getRandomRows(sqlContext, size / 2, partitionPath, false); + List internalRows = toInternalRows(inputRows, ENCODER); + + // generate some failures rows + for (int i = 0; i < totalFailures; i++) { + internalRows.add(getInternalRowWithError(partitionPath)); + } + + // generate 2nd batch of valid rows + Dataset inputRows2 = getRandomRows(sqlContext, size / 2, partitionPath, false); + internalRows.addAll(toInternalRows(inputRows2, ENCODER)); + + // issue writes + try { + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + fail("Should have failed"); + } catch (Throwable e) { + // expected + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + + Option> fileAbsPaths = Option.of(new ArrayList<>()); + Option> fileNames = Option.of(new ArrayList<>()); + // verify write statuses + assertWriteStatuses(commitMetadata.getWriteStatuses(), 1, size / 2, fileAbsPaths, fileNames); + + // verify rows + Dataset result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0])); + assertOutput(inputRows, result, instantTime, fileNames, true); + } + + private void writeRows(Dataset inputRows, HoodieBulkInsertDataInternalWriter writer) + throws Exception { + List internalRows = toInternalRows(inputRows, ENCODER); + // issue writes + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java 
b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java new file mode 100644 index 0000000000000..31d606de4a1ef --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/test/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi.spark3.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.common.model.HoodieCommitMetadata; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase; +import org.apache.hudi.table.HoodieSparkTable; +import org.apache.hudi.table.HoodieTable; +import org.apache.hudi.testutils.HoodieClientTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.DataWriter; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests {@link HoodieDataSourceInternalBatchWrite}. 
+ */ +public class TestHoodieDataSourceInternalBatchWrite extends + HoodieBulkInsertInternalWriterTestBase { + + private static Stream bulkInsertTypeParams() { + Object[][] data = new Object[][] { + {true}, + {false} + }; + return Stream.of(data).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testDataSourceWriter(boolean populateMetaFields) throws Exception { + testDataSourceWriterInternal(Collections.emptyMap(), Collections.emptyMap(), populateMetaFields); + } + + private void testDataSourceWriterInternal(Map extraMetadata, Map expectedExtraMetadata, boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String instantTime = "001"; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, extraMetadata, populateMetaFields, false); + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong()); + + String[] partitionPaths = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS; + List partitionPathsAbs = new ArrayList<>(); + for (String partitionPath : partitionPaths) { + partitionPathsAbs.add(basePath + "/" + partitionPath + "/*"); + } + + int size = 10 + RANDOM.nextInt(1000); + int batches = 5; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + List commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + + metaClient.reloadActiveTimeline(); + Dataset result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + + // verify extra metadata + Option commitMetadataOption = HoodieClientTestUtils.getCommitMetadataForLatestInstant(metaClient); + assertTrue(commitMetadataOption.isPresent()); + Map actualExtraMetadata = new HashMap<>(); + commitMetadataOption.get().getExtraMetadata().entrySet().stream().filter(entry -> + !entry.getKey().equals(HoodieCommitMetadata.SCHEMA_KEY)).forEach(entry -> actualExtraMetadata.put(entry.getKey(), entry.getValue())); + assertEquals(actualExtraMetadata, expectedExtraMetadata); + } + + @Test + public void testDataSourceWriterExtraCommitMetadata() throws Exception { + String commitExtraMetaPrefix = "commit_extra_meta_"; + Map extraMeta = new HashMap<>(); + extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix); + extraMeta.put(commitExtraMetaPrefix + "a", "valA"); + extraMeta.put(commitExtraMetaPrefix + "b", "valB"); + extraMeta.put("commit_extra_c", "valC"); // should not be part of commit extra metadata + + 
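+ // only entries whose keys carry the configured prefix should end up in the commit metadata; + // the prefix config entry itself and unprefixed keys are expected to be filtered out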
Map expectedMetadata = new HashMap<>(); + expectedMetadata.putAll(extraMeta); + expectedMetadata.remove(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key()); + expectedMetadata.remove("commit_extra_c"); + + testDataSourceWriterInternal(extraMeta, expectedMetadata, true); + } + + @Test + public void testDataSourceWriterEmptyExtraCommitMetadata() throws Exception { + String commitExtraMetaPrefix = "commit_extra_meta_"; + Map extraMeta = new HashMap<>(); + extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix); + extraMeta.put("keyA", "valA"); + extraMeta.put("keyB", "valB"); + extraMeta.put("commit_extra_c", "valC"); + // none of the keys has commit metadata key prefix. + testDataSourceWriterInternal(extraMeta, Collections.emptyMap(), true); + } + + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testMultipleDataSourceWrites(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + int partitionCounter = 0; + + // execute N rounds + for (int i = 0; i < 2; i++) { + String instantTime = "00" + i; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + List commitMessages = new ArrayList<>(); + Dataset totalInputRows = null; + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong()); + + int size = 10 + RANDOM.nextInt(1000); + int batches = 3; // one batch per partition + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + + Dataset result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime, populateMetaFields); + + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + } + } + + // Large writes are not required to be executed w/ regular CI jobs. Takes lot of running time. 
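+ // @Disabled keeps them out of CI while leaving them compiled and runnable on demand.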
+ @Disabled + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testLargeWrites(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + int partitionCounter = 0; + + // execute N rounds + for (int i = 0; i < 3; i++) { + String instantTime = "00" + i; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + List commitMessages = new ArrayList<>(); + Dataset totalInputRows = null; + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong()); + + int size = 10000 + RANDOM.nextInt(10000); + int batches = 3; // one batch per partition + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + + Dataset result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime, + populateMetaFields); + + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + } + } + + /** + * Tests that DataSourceWriter.abort() will abort the written records of interest write and commit batch1 write and abort batch2 Read of entire dataset should show only records from batch1. 
+ * commit batch1 + * abort batch2 + * verify only records from batch1 is available to read + */ + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testAbort(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String instantTime0 = "00" + 0; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime0, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong()); + + List partitionPaths = Arrays.asList(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS); + List partitionPathsAbs = new ArrayList<>(); + for (String partitionPath : partitionPaths) { + partitionPathsAbs.add(basePath + "/" + partitionPath + "/*"); + } + + int size = 10 + RANDOM.nextInt(100); + int batches = 1; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + List commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + // commit 1st batch + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + Dataset result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify rows + assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + + // 2nd batch. 
abort in the end + String instantTime1 = "00" + 1; + dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime1, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(1, RANDOM.nextLong()); + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + } + + commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + // abort 2nd batch + dataSourceInternalBatchWrite.abort(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify rows + // only rows from first batch should be present + assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields); + } + + private void writeRows(Dataset<Row> inputRows, DataWriter<InternalRow> writer) throws Exception { + List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER); + // issue writes + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.2.x/pom.xml b/hudi-spark-datasource/hudi-spark3.2.x/pom.xml index 8f08c96a90c74..fae380c2e894c 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/pom.xml +++ b/hudi-spark-datasource/hudi-spark3.2.x/pom.xml @@ -196,12 +196,6 @@ ${spark32.version} provided true - <exclusions> - <exclusion> - <groupId>*</groupId> - <artifactId>*</artifactId> - </exclusion> - </exclusions> @@ -315,6 +309,8 @@ test-jar test + + <groupId>org.apache.parquet</groupId> <artifactId>parquet-avro</artifactId> diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala index 9cd85ca8a53ef..1eaa99ac77f6d 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala @@ -18,12 +18,17 @@ package org.apache.spark.sql import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark32CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils with PredicateHelper { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + RowEncoder.apply(schema).resolveAndBind() + } + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = DataSourceStrategy.normalizeExprs(exprs, attributes) diff --git 
a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32SchemaUtils.scala index 03931067d6e50..b5127fe328f7e 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark32SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, colType, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark32PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark32PartitionedFileUtils.scala index a5e4c04a17093..a9fac5d45ef7a 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark32PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark32PartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] for Spark 3.2. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.2. 
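+ * The added helpers shield callers from version-specific [[PartitionDirectory]] internals.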
*/ object HoodieSpark32PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -40,4 +40,12 @@ object HoodieSpark32PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, filePath.toUri.toString, start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32PlusDataSourceUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32DataSourceUtils.scala similarity index 98% rename from hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32PlusDataSourceUtils.scala rename to hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32DataSourceUtils.scala index 5c3f5a976c25f..6d1c76380f216 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32PlusDataSourceUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32DataSourceUtils.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy import org.apache.spark.util.Utils -object Spark32PlusDataSourceUtils { +object Spark32DataSourceUtils { /** * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32LegacyHoodieParquetFileFormat.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32LegacyHoodieParquetFileFormat.scala index c88c35b5eeb4e..6099e4ac25aca 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32LegacyHoodieParquetFileFormat.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark32LegacyHoodieParquetFileFormat.scala @@ -185,7 +185,7 @@ class Spark32LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark32DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) createParquetFilters( parquetSchema, pushDownDate, @@ -285,9 +285,9 @@ class Spark32LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark32DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) 
+ Spark32DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createVectorizedParquetRecordReader( convertTz.orNull, datetimeRebaseMode.toString, @@ -347,9 +347,9 @@ class Spark32LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu int96RebaseSpec) } else { val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark32DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + Spark32DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createParquetReadSupport( convertTz, /* enableVectorizedReader = */ false, diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32Analysis.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32Analysis.scala new file mode 100644 index 0000000000000..f139e8beb7fba --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32Analysis.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.analysis + +import org.apache.hudi.DefaultSource + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.{SQLContext, SparkSession} + +/** + * NOTE: PLEASE READ CAREFULLY + * + * Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here. + * Such fallback will have considerable performance impact, therefore it's only performed in cases + * where V2 API have to be used. 
Currently, the only such use case is the Schema Evolution feature + * + * Check out HUDI-4178 for more details + */ +case class HoodieSpark32DataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] + with ProvidesHoodieConfig { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + // The only place we're avoiding fallback is in [[AlterTableCommand]]s since + // current implementation relies on DSv2 features + case _: AlterTableCommand => plan + + // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose + // target relation as a child (even though there's no good reason for that) + case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _) => + iis.copy(table = convertToV1(rv2, v2Table)) + + case _ => + plan.resolveOperatorsDown { + case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table) + } + } + + private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { + val output = rv2.output + val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) + val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), + buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) + + LogicalRelation(relation, output, catalogTable, isStreaming = false) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. + */ +public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness { + + protected static final Random RANDOM = new Random(); + + @BeforeEach + public void setUp() throws Exception { + initSparkContexts(); + initPath(); + initFileSystem(); + initTestDataGenerator(); + initMetaClient(); + initTimelineService(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanupResources(); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) { + return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue()); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) { + Properties properties = new Properties(); + if (!populateMetaFields) { + properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME); + properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME); + properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false"); + } + properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue); + return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build(); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, + Option> fileAbsPaths, Option> fileNames) { + assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, boolean areRecordsSorted, + Option> fileAbsPaths, Option> fileNames, boolean isHiveStylePartitioning) { + if (areRecordsSorted) { + assertEquals(batches, writeStatuses.size()); + } else { + assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size()); + } + + Map sizeMap = new HashMap<>(); + if (!areRecordsSorted) { + // no of records are written per batch. Every 4th batch goes into same writeStatus. 
So, populating the expected size + // per write status + for (int i = 0; i < batches; i++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3]; + if (!sizeMap.containsKey(partitionPath)) { + sizeMap.put(partitionPath, 0L); + } + sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size); + } + } + + int counter = 0; + for (WriteStatus writeStatus : writeStatuses) { + // verify write status + String actualPartitionPathFormat = isHiveStylePartitioning ? SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s"; + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath()); + if (areRecordsSorted) { + assertEquals(writeStatus.getTotalRecords(), size); + } else { + assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3])); + } + assertNull(writeStatus.getGlobalError()); + assertEquals(writeStatus.getTotalErrorRecords(), 0); + assertFalse(writeStatus.hasErrors()); + assertNotNull(writeStatus.getFileId()); + String fileId = writeStatus.getFileId(); + if (fileAbsPaths.isPresent()) { + fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath()); + } + if (fileNames.isPresent()) { + fileNames.get().add(writeStatus.getStat().getPath() + .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1)); + } + HoodieWriteStat writeStat = writeStatus.getStat(); + if (areRecordsSorted) { + assertEquals(size, writeStat.getNumInserts()); + assertEquals(size, writeStat.getNumWrites()); + } else { + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts()); + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites()); + } + assertEquals(fileId, writeStat.getFileId()); + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath()); + assertEquals(0, writeStat.getNumDeletes()); + assertEquals(0, writeStat.getNumUpdateWrites()); + assertEquals(0, writeStat.getTotalWriteErrors()); + } + } + + protected void assertOutput(Dataset expectedRows, Dataset actualRows, String instantTime, Option> fileNames, + boolean populateMetaColumns) { + if (populateMetaColumns) { + // verify 3 meta fields that are filled in within create handle + actualRows.collectAsList().forEach(entry -> { + assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime); + assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())); + if (fileNames.isPresent()) { + assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()))); + } + assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal())); + }); + + // after trimming the meta fields, the rest of the fields should match + Dataset trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + Dataset trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + assertEquals(0, trimmedActual.except(trimmedExpected).count()); + } else { // operation = BULK_INSERT_APPEND_ONLY + // all meta columns are untouched + assertEquals(0, 
expectedRows.except(actualRows).count()); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java new file mode 100644 index 0000000000000..206d4931b15e1 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.hudi.spark3.internal; + +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase; +import org.apache.hudi.table.HoodieSparkTable; +import org.apache.hudi.table.HoodieTable; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getInternalRowWithError; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Unit tests {@link HoodieBulkInsertDataInternalWriter}. 
+ */ +public class TestHoodieBulkInsertDataInternalWriter extends + HoodieBulkInsertInternalWriterTestBase { + + private static Stream configParams() { + Object[][] data = new Object[][] { + {true, true}, + {true, false}, + {false, true}, + {false, false} + }; + return Stream.of(data).map(Arguments::of); + } + + private static Stream bulkInsertTypeParams() { + Object[][] data = new Object[][] { + {true}, + {false} + }; + return Stream.of(data).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("configParams") + public void testDataInternalWriter(boolean sorted, boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + // execute N rounds + for (int i = 0; i < 2; i++) { + String instantTime = "00" + i; + // init writer + HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000), + RANDOM.nextLong(), STRUCT_TYPE, populateMetaFields, sorted); + + int size = 10 + RANDOM.nextInt(1000); + // write N rows to partition1, N rows to partition2 and N rows to partition3 ... Each batch should create a new RowCreateHandle and a new file + int batches = 3; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + Option> fileAbsPaths = Option.of(new ArrayList<>()); + Option> fileNames = Option.of(new ArrayList<>()); + + // verify write statuses + assertWriteStatuses(commitMetadata.getWriteStatuses(), batches, size, sorted, fileAbsPaths, fileNames, false); + + // verify rows + Dataset result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0])); + assertOutput(totalInputRows, result, instantTime, fileNames, populateMetaFields); + } + } + + + /** + * Issues some corrupted or wrongly schematized InternalRows after a few valid InternalRows so that a global error is thrown: writes batch 1 of valid records, then batch 2 of invalid records, which is expected + * to throw a global error. Verifies that the global error is set appropriately and that only the first batch of records is written to disk. 
+ */ + @Test + public void testGlobalFailure() throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(true); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[0]; + + String instantTime = "001"; + HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000), + RANDOM.nextLong(), STRUCT_TYPE, true, false); + + int size = 10 + RANDOM.nextInt(100); + int totalFailures = 5; + // Generate first batch of valid rows + Dataset inputRows = getRandomRows(sqlContext, size / 2, partitionPath, false); + List internalRows = toInternalRows(inputRows, ENCODER); + + // generate some failure rows + for (int i = 0; i < totalFailures; i++) { + internalRows.add(getInternalRowWithError(partitionPath)); + } + + // generate 2nd batch of valid rows + Dataset inputRows2 = getRandomRows(sqlContext, size / 2, partitionPath, false); + internalRows.addAll(toInternalRows(inputRows2, ENCODER)); + + // issue writes + try { + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + fail("Should have failed"); + } catch (Throwable e) { + // expected + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + + Option> fileAbsPaths = Option.of(new ArrayList<>()); + Option> fileNames = Option.of(new ArrayList<>()); + // verify write statuses + assertWriteStatuses(commitMetadata.getWriteStatuses(), 1, size / 2, fileAbsPaths, fileNames); + + // verify rows + Dataset result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0])); + assertOutput(inputRows, result, instantTime, fileNames, true); + } + + private void writeRows(Dataset inputRows, HoodieBulkInsertDataInternalWriter writer) + throws Exception { + List internalRows = toInternalRows(inputRows, ENCODER); + // issue writes + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java new file mode 100644 index 0000000000000..31d606de4a1ef --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.hudi.spark3.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.common.model.HoodieCommitMetadata; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase; +import org.apache.hudi.table.HoodieSparkTable; +import org.apache.hudi.table.HoodieTable; +import org.apache.hudi.testutils.HoodieClientTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.connector.write.DataWriter; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows; +import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests {@link HoodieDataSourceInternalBatchWrite}. + */ +public class TestHoodieDataSourceInternalBatchWrite extends + HoodieBulkInsertInternalWriterTestBase { + + private static Stream bulkInsertTypeParams() { + Object[][] data = new Object[][] { + {true}, + {false} + }; + return Stream.of(data).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testDataSourceWriter(boolean populateMetaFields) throws Exception { + testDataSourceWriterInternal(Collections.emptyMap(), Collections.emptyMap(), populateMetaFields); + } + + private void testDataSourceWriterInternal(Map extraMetadata, Map expectedExtraMetadata, boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String instantTime = "001"; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, extraMetadata, populateMetaFields, false); + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong()); + + String[] partitionPaths = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS; + List partitionPathsAbs = new ArrayList<>(); + for (String partitionPath : partitionPaths) { + partitionPathsAbs.add(basePath + "/" + partitionPath + "/*"); + } + + int size = 10 + RANDOM.nextInt(1000); + int batches = 5; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } 
+ + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + List commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + + metaClient.reloadActiveTimeline(); + Dataset result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + + // verify extra metadata + Option commitMetadataOption = HoodieClientTestUtils.getCommitMetadataForLatestInstant(metaClient); + assertTrue(commitMetadataOption.isPresent()); + Map actualExtraMetadata = new HashMap<>(); + commitMetadataOption.get().getExtraMetadata().entrySet().stream().filter(entry -> + !entry.getKey().equals(HoodieCommitMetadata.SCHEMA_KEY)).forEach(entry -> actualExtraMetadata.put(entry.getKey(), entry.getValue())); + assertEquals(actualExtraMetadata, expectedExtraMetadata); + } + + @Test + public void testDataSourceWriterExtraCommitMetadata() throws Exception { + String commitExtraMetaPrefix = "commit_extra_meta_"; + Map extraMeta = new HashMap<>(); + extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix); + extraMeta.put(commitExtraMetaPrefix + "a", "valA"); + extraMeta.put(commitExtraMetaPrefix + "b", "valB"); + extraMeta.put("commit_extra_c", "valC"); // should not be part of commit extra metadata + + Map expectedMetadata = new HashMap<>(); + expectedMetadata.putAll(extraMeta); + expectedMetadata.remove(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key()); + expectedMetadata.remove("commit_extra_c"); + + testDataSourceWriterInternal(extraMeta, expectedMetadata, true); + } + + @Test + public void testDataSourceWriterEmptyExtraCommitMetadata() throws Exception { + String commitExtraMetaPrefix = "commit_extra_meta_"; + Map extraMeta = new HashMap<>(); + extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix); + extraMeta.put("keyA", "valA"); + extraMeta.put("keyB", "valB"); + extraMeta.put("commit_extra_c", "valC"); + // none of the keys has commit metadata key prefix. 
+ testDataSourceWriterInternal(extraMeta, Collections.emptyMap(), true); + } + + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testMultipleDataSourceWrites(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + int partitionCounter = 0; + + // execute N rounds + for (int i = 0; i < 2; i++) { + String instantTime = "00" + i; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + List commitMessages = new ArrayList<>(); + Dataset totalInputRows = null; + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong()); + + int size = 10 + RANDOM.nextInt(1000); + int batches = 3; // one batch per partition + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + + Dataset result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime, populateMetaFields); + + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + } + } + + // Large writes are not required to be executed w/ regular CI jobs, as they take a lot of running time. 
+ @Disabled + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testLargeWrites(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + int partitionCounter = 0; + + // execute N rounds + for (int i = 0; i < 3; i++) { + String instantTime = "00" + i; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + List commitMessages = new ArrayList<>(); + Dataset totalInputRows = null; + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong()); + + int size = 10000 + RANDOM.nextInt(10000); + int batches = 3; // one batch per partition + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages.add(commitMetadata); + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + + Dataset result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime, + populateMetaFields); + + // verify output + assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + } + } + + /** + * Tests that DataSourceWriter.abort() aborts the records written by the aborted batch: write and commit batch1, then write and abort batch2; a read of the entire dataset should show only records from batch1. 
+ * commit batch1 + * abort batch2 + * verify only records from batch1 are available to read + */ + @ParameterizedTest + @MethodSource("bulkInsertTypeParams") + public void testAbort(boolean populateMetaFields) throws Exception { + // init config and table + HoodieWriteConfig cfg = getWriteConfig(populateMetaFields); + HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient); + String instantTime0 = "00" + 0; + // init writer + HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime0, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + DataWriter writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong()); + + List partitionPaths = Arrays.asList(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS); + List partitionPathsAbs = new ArrayList<>(); + for (String partitionPath : partitionPaths) { + partitionPathsAbs.add(basePath + "/" + partitionPath + "/*"); + } + + int size = 10 + RANDOM.nextInt(100); + int batches = 1; + Dataset totalInputRows = null; + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + if (totalInputRows == null) { + totalInputRows = inputRows; + } else { + totalInputRows = totalInputRows.union(inputRows); + } + } + + HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + List commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + // commit 1st batch + dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + Dataset result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify rows + assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields); + assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty()); + + // 2nd batch, abort at the end
+ String instantTime1 = "00" + 1; + dataSourceInternalBatchWrite = + new HoodieDataSourceInternalBatchWrite(instantTime1, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.emptyMap(), populateMetaFields, false); + writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(1, RANDOM.nextLong()); + + for (int j = 0; j < batches; j++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3]; + Dataset inputRows = getRandomRows(sqlContext, size, partitionPath, false); + writeRows(inputRows, writer); + } + + commitMetadata = (HoodieWriterCommitMessage) writer.commit(); + commitMessages = new ArrayList<>(); + commitMessages.add(commitMetadata); + // abort 2nd batch + dataSourceInternalBatchWrite.abort(commitMessages.toArray(new HoodieWriterCommitMessage[0])); + metaClient.reloadActiveTimeline(); + result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0])); + // verify rows + // only rows from first batch should be present + assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields); + } + + private void writeRows(Dataset inputRows, DataWriter writer) throws Exception { + List internalRows = toInternalRows(inputRows, ENCODER); + // issue writes + for (InternalRow internalRow : internalRows) { + writer.write(internalRow); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala index 1c08717455aae..25aca8aaaa69f 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala @@ -48,34 +48,6 @@ import org.apache.spark.sql.{AnalysisException, SQLContext, SparkSession} * * Check out HUDI-4178 for more details */ -case class HoodieDataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] - with ProvidesHoodieConfig { - - override def apply(plan: LogicalPlan): LogicalPlan = plan match { - // The only place we're avoiding fallback is in [[AlterTableCommand]]s since - // current implementation relies on DSv2 features - case _: AlterTableCommand => plan - - // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose - // target relation as a child (even though there's no good reason for that) - case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _) => - iis.copy(table = convertToV1(rv2, v2Table)) - - case _ => - plan.resolveOperatorsDown { - case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table) - } - } - - private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { - val output = rv2.output - val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) - val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), - buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) - - LogicalRelation(relation, output, catalogTable, isStreaming = false) - } -} /** * Rule for resolve hoodie's extended syntax or rewrite some logical plan. 
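[Editor's note] Every per-version module in this PR overrides getEncoder with RowEncoder.apply(schema).resolveAndBind() (Spark 3.2 above, Spark 3.3 and 3.4 below); the override exists precisely because Spark 3.5 removed that factory, which is why the shared code can no longer call RowEncoder directly. A plausible sketch of the hudi-spark3.5.x counterpart, which is not shown in this excerpt; it assumes Spark 3.5's ExpressionEncoder(schema) factory as the replacement:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.types.StructType

// Hypothetical Spark 3.5 variant: the Row encoder is built straight from the
// schema via the ExpressionEncoder companion object, then resolved and bound,
// matching what the 3.2/3.3/3.4 overrides achieve with RowEncoder.apply.
object HoodieSpark35CatalystExpressionUtilsSketch {
  def getEncoder(schema: StructType): ExpressionEncoder[Row] = {
    ExpressionEncoder(schema).resolveAndBind()
  }
}
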
diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala index 3ba5ed3d99910..29c2ac57da01b 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql -import HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Add, AnsiCast, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark33CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils with PredicateHelper { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + RowEncoder.apply(schema).resolveAndBind() + } + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = DataSourceStrategy.normalizeExprs(exprs, attributes) diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33SchemaUtils.scala index 37563a61ca64a..f31dadd0c3174 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark33SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, colType, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark33PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark33PartitionedFileUtils.scala index 39e9c8efe3477..220825a6875da 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark33PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark33PartitionedFileUtils.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.sql.catalyst.InternalRow /** 
- * Utils on Spark [[PartitionedFile]] for Spark 3.3. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.3. */ object HoodieSpark33PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -40,4 +40,12 @@ object HoodieSpark33PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, filePath.toUri.toString, start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33DataSourceUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33DataSourceUtils.scala new file mode 100644 index 0000000000000..2aa85660eb511 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33DataSourceUtils.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.util.Utils + +object Spark33DataSourceUtils { + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def int96RebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
+ if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def datetimeRebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + +} diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33LegacyHoodieParquetFileFormat.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33LegacyHoodieParquetFileFormat.scala index de6cbff90ca54..3b53b753b69d2 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33LegacyHoodieParquetFileFormat.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark33LegacyHoodieParquetFileFormat.scala @@ -187,7 +187,7 @@ class Spark33LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) createParquetFilters( parquetSchema, pushDownDate, @@ -287,9 +287,9 @@ class Spark33LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + Spark33DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createVectorizedParquetRecordReader( convertTz.orNull, datetimeRebaseMode.toString, @@ -349,9 +349,9 @@ class Spark33LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu int96RebaseSpec) } else { val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark33DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - 
Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + Spark33DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createParquetReadSupport( convertTz, /* enableVectorizedReader = */ false, diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark33Analysis.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark33Analysis.scala new file mode 100644 index 0000000000000..3273d23e7c897 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark33Analysis.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.analysis + +import org.apache.hudi.DefaultSource + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.{SQLContext, SparkSession} + +/** + * NOTE: PLEASE READ CAREFULLY + * + * Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here. + * Such fallback will have considerable performance impact, therefore it's only performed in cases + * where V2 API have to be used. 
Currently, the only such use case is the Schema Evolution feature + * + * Check out HUDI-4178 for more details + */ +case class HoodieSpark33DataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] + with ProvidesHoodieConfig { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + // The only place we're avoiding fallback is in [[AlterTableCommand]]s since + // current implementation relies on DSv2 features + case _: AlterTableCommand => plan + + // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose + // target relation as a child (even though there's no good reason for that) + case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _) => + iis.copy(table = convertToV1(rv2, v2Table)) + + case _ => + plan.resolveOperatorsDown { + case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table) + } + } + + private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { + val output = rv2.output + val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) + val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), + buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) + + LogicalRelation(relation, output, catalogTable, isStreaming = false) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. + */ +public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness { + + protected static final Random RANDOM = new Random(); + + @BeforeEach + public void setUp() throws Exception { + initSparkContexts(); + initPath(); + initFileSystem(); + initTestDataGenerator(); + initMetaClient(); + initTimelineService(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanupResources(); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) { + return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue()); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) { + Properties properties = new Properties(); + if (!populateMetaFields) { + properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME); + properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME); + properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false"); + } + properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue); + return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build(); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, + Option> fileAbsPaths, Option> fileNames) { + assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, boolean areRecordsSorted, + Option> fileAbsPaths, Option> fileNames, boolean isHiveStylePartitioning) { + if (areRecordsSorted) { + assertEquals(batches, writeStatuses.size()); + } else { + assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size()); + } + + Map sizeMap = new HashMap<>(); + if (!areRecordsSorted) { + // no of records are written per batch. Every 4th batch goes into same writeStatus. 
So, populating the size expected + // per write status + for (int i = 0; i < batches; i++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3]; + if (!sizeMap.containsKey(partitionPath)) { + sizeMap.put(partitionPath, 0L); + } + sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size); + } + } + + int counter = 0; + for (WriteStatus writeStatus : writeStatuses) { + // verify write status + String actualPartitionPathFormat = isHiveStylePartitioning ? SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s"; + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath()); + if (areRecordsSorted) { + assertEquals(writeStatus.getTotalRecords(), size); + } else { + assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3])); + } + assertNull(writeStatus.getGlobalError()); + assertEquals(writeStatus.getTotalErrorRecords(), 0); + assertFalse(writeStatus.hasErrors()); + assertNotNull(writeStatus.getFileId()); + String fileId = writeStatus.getFileId(); + if (fileAbsPaths.isPresent()) { + fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath()); + } + if (fileNames.isPresent()) { + fileNames.get().add(writeStatus.getStat().getPath() + .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1)); + } + HoodieWriteStat writeStat = writeStatus.getStat(); + if (areRecordsSorted) { + assertEquals(size, writeStat.getNumInserts()); + assertEquals(size, writeStat.getNumWrites()); + } else { + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts()); + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites()); + } + assertEquals(fileId, writeStat.getFileId()); + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath()); + assertEquals(0, writeStat.getNumDeletes()); + assertEquals(0, writeStat.getNumUpdateWrites()); + assertEquals(0, writeStat.getTotalWriteErrors()); + } + } + + protected void assertOutput(Dataset expectedRows, Dataset actualRows, String instantTime, Option> fileNames, + boolean populateMetaColumns) { + if (populateMetaColumns) { + // verify 3 meta fields that are filled in within create handle + actualRows.collectAsList().forEach(entry -> { + assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime); + assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())); + if (fileNames.isPresent()) { + assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()))); + } + assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal())); + }); + + // after trimming the meta fields, the rest of the fields should match + Dataset trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + Dataset trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + assertEquals(0, trimmedActual.except(trimmedExpected).count()); + } else { // operation = BULK_INSERT_APPEND_ONLY + // all meta columns are untouched + assertEquals(0,
expectedRows.except(actualRows).count()); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java b/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java index 0d1867047847b..0763a22f032c0 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java @@ -45,7 +45,8 @@ public void testDataSourceWriterExtraCommitMetadata() throws Exception { scala.collection.immutable.List.empty(), statement.query(), statement.overwrite(), - statement.ifPartitionNotExists()); + statement.ifPartitionNotExists(), + false); Assertions.assertTrue( ((UnresolvedRelation)newStatment.table()).multipartIdentifier().contains("test_reflect_util")); diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34CatalystExpressionUtils.scala index e93228a47ee5a..c36ca1ed55b4c 100644 --- a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34CatalystExpressionUtils.scala @@ -18,12 +18,17 @@ package org.apache.spark.sql import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, EvalMode, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, StructType} object HoodieSpark34CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils with PredicateHelper { + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + RowEncoder.apply(schema).resolveAndBind() + } + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = { DataSourceStrategy.normalizeExprs(exprs, attributes) } diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34SchemaUtils.scala index d597544d26312..d6cf4a3fad078 100644 --- a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34SchemaUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/HoodieSpark34SchemaUtils.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** @@ -30,4 +32,8 @@ object HoodieSpark34SchemaUtils extends HoodieSchemaUtils { caseSensitiveAnalysis: Boolean): Unit = { SchemaUtils.checkColumnNameDuplication(columnNames, caseSensitiveAnalysis) } + + override def toAttributes(struct: StructType): 
Seq[Attribute] = { + struct.toAttributes + } } diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark34PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark34PartitionedFileUtils.scala index 249d7e59051df..cfbf22246c5f9 100644 --- a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark34PartitionedFileUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark34PartitionedFileUtils.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.datasources -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.paths.SparkPath import org.apache.spark.sql.catalyst.InternalRow /** - * Utils on Spark [[PartitionedFile]] for Spark 3.4. + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.4. */ object HoodieSpark34PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { @@ -41,4 +41,12 @@ object HoodieSpark34PartitionedFileUtils extends HoodieSparkPartitionedFileUtils length: Long): PartitionedFile = { PartitionedFile(partitionValues, SparkPath.fromPath(filePath), start, length) } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses) + } } diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34DataSourceUtils.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34DataSourceUtils.scala new file mode 100644 index 0000000000000..d404bc8c24b53 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34DataSourceUtils.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy +import org.apache.spark.util.Utils + +object Spark34DataSourceUtils { + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def int96RebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and later may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def datetimeRebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and later may also need the rebase if they were written with + // the "LEGACY" rebase mode.
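+      // For example: a footer written by Spark 2.4.x rebases as LEGACY below, as does any footer +      // carrying the "org.apache.spark.legacyDateTime" marker; a footer written by Spark 3.4.x +      // without that marker reads back as CORRECTED.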
+ if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + +} diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34LegacyHoodieParquetFileFormat.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34LegacyHoodieParquetFileFormat.scala index 6de8ded06ec00..cd76ce6f3b2e1 100644 --- a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34LegacyHoodieParquetFileFormat.scala +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark34LegacyHoodieParquetFileFormat.scala @@ -203,7 +203,7 @@ class Spark34LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark34DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) createParquetFilters( parquetSchema, pushDownDate, @@ -303,9 +303,9 @@ class Spark34LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu } else { // Spark 3.2.0 val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark34DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + Spark34DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createVectorizedParquetRecordReader( convertTz.orNull, datetimeRebaseMode.toString, @@ -365,9 +365,9 @@ class Spark34LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValu int96RebaseSpec) } else { val datetimeRebaseMode = - Spark32PlusDataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + Spark34DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) val int96RebaseMode = - Spark32PlusDataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + Spark34DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) createParquetReadSupport( convertTz, /* enableVectorizedReader = */ false, diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark34Analysis.scala b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark34Analysis.scala new file mode 100644 index 0000000000000..9194a667a8900 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark34Analysis.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.analysis + +import org.apache.hudi.DefaultSource + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.{SQLContext, SparkSession} + +/** + * NOTE: PLEASE READ CAREFULLY + * + * Since Hudi relations don't currently implement the DS V2 Read API, we have to fall back to V1 here. + * Such a fallback has a considerable performance impact, therefore it's only performed in cases + * where the V2 API has to be used. Currently the only such use case is the Schema Evolution feature + * + * Check out HUDI-4178 for more details + */ +case class HoodieSpark34DataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] + with ProvidesHoodieConfig { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + // The only place we're avoiding fallback is in [[AlterTableCommand]]s since + // current implementation relies on DSv2 features + case _: AlterTableCommand => plan + + // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose + // the target relation as a child (even though there's no good reason for that) + case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _) => + iis.copy(table = convertToV1(rv2, v2Table)) + + case _ => + plan.resolveOperatorsDown { + case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table) + } + } + + private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { + val output = rv2.output + val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) + val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), + buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) + + LogicalRelation(relation, output, catalogTable, isStreaming = false) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. 
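+ * Provides shared assertions over write statuses and written rows. Typical (illustrative) usage: + * a subclass builds a config via {@code getWriteConfig(true)}, runs the bulk-insert row writer + * under test, and verifies the result via {@code assertWriteStatuses} and {@code assertOutput}.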
+ */ +public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness { + + protected static final Random RANDOM = new Random(); + + @BeforeEach + public void setUp() throws Exception { + initSparkContexts(); + initPath(); + initFileSystem(); + initTestDataGenerator(); + initMetaClient(); + initTimelineService(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanupResources(); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) { + return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue()); + } + + protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) { + Properties properties = new Properties(); + if (!populateMetaFields) { + properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME); + properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME); + properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false"); + } + properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue); + return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build(); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, + Option> fileAbsPaths, Option> fileNames) { + assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false); + } + + protected void assertWriteStatuses(List writeStatuses, int batches, int size, boolean areRecordsSorted, + Option> fileAbsPaths, Option> fileNames, boolean isHiveStylePartitioning) { + if (areRecordsSorted) { + assertEquals(batches, writeStatuses.size()); + } else { + assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size()); + } + + Map sizeMap = new HashMap<>(); + if (!areRecordsSorted) { + // no of records are written per batch. Every 4th batch goes into same writeStatus. So, populating the size expected + // per write status + for (int i = 0; i < batches; i++) { + String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3]; + if (!sizeMap.containsKey(partitionPath)) { + sizeMap.put(partitionPath, 0L); + } + sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size); + } + } + + int counter = 0; + for (WriteStatus writeStatus : writeStatuses) { + // verify write status + String actualPartitionPathFormat = isHiveStylePartitioning ? 
SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s"; + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath()); + if (areRecordsSorted) { + assertEquals(writeStatus.getTotalRecords(), size); + } else { + assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3])); + } + assertNull(writeStatus.getGlobalError()); + assertEquals(writeStatus.getTotalErrorRecords(), 0); + assertFalse(writeStatus.hasErrors()); + assertNotNull(writeStatus.getFileId()); + String fileId = writeStatus.getFileId(); + if (fileAbsPaths.isPresent()) { + fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath()); + } + if (fileNames.isPresent()) { + fileNames.get().add(writeStatus.getStat().getPath() + .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1)); + } + HoodieWriteStat writeStat = writeStatus.getStat(); + if (areRecordsSorted) { + assertEquals(size, writeStat.getNumInserts()); + assertEquals(size, writeStat.getNumWrites()); + } else { + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts()); + assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites()); + } + assertEquals(fileId, writeStat.getFileId()); + assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath()); + assertEquals(0, writeStat.getNumDeletes()); + assertEquals(0, writeStat.getNumUpdateWrites()); + assertEquals(0, writeStat.getTotalWriteErrors()); + } + } + + protected void assertOutput(Dataset expectedRows, Dataset actualRows, String instantTime, Option> fileNames, + boolean populateMetaColumns) { + if (populateMetaColumns) { + // verify 3 meta fields that are filled in within create handle + actualRows.collectAsList().forEach(entry -> { + assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime); + assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())); + if (fileNames.isPresent()) { + assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()))); + } + assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal())); + }); + + // after trimming the meta fields, the rest of the fields should match + Dataset trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + Dataset trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD); + assertEquals(0, trimmedActual.except(trimmedExpected).count()); + } else { // operation = BULK_INSERT_APPEND_ONLY + // all meta columns are untouched + assertEquals(0, expectedRows.except(actualRows).count()); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java b/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java index 0d1867047847b..0763a22f032c0 100644 --- a/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java +++
b/hudi-spark-datasource/hudi-spark3.4.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java @@ -45,7 +45,8 @@ public void testDataSourceWriterExtraCommitMetadata() throws Exception { scala.collection.immutable.List.empty(), statement.query(), statement.overwrite(), - statement.ifPartitionNotExists()); + statement.ifPartitionNotExists(), + false); Assertions.assertTrue( ((UnresolvedRelation)newStatment.table()).multipartIdentifier().contains("test_reflect_util")); diff --git a/hudi-spark-datasource/hudi-spark3.5.x/pom.xml b/hudi-spark-datasource/hudi-spark3.5.x/pom.xml new file mode 100644 index 0000000000000..ba2b0cb9e34d5 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/pom.xml @@ -0,0 +1,342 @@ + + + + + hudi-spark-datasource + org.apache.hudi + 1.0.0-SNAPSHOT + + 4.0.0 + + hudi-spark3.5.x_2.12 + 1.0.0-SNAPSHOT + + hudi-spark3.5.x_2.12 + jar + + + ${project.parent.parent.basedir} + + + + + + src/main/resources + + + + + + net.alchim31.maven + scala-maven-plugin + ${scala-maven-plugin.version} + + + -nobootcp + + false + + + + org.apache.maven.plugins + maven-compiler-plugin + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies + prepare-package + + copy-dependencies + + + ${project.build.directory}/lib + true + true + true + + + + + + net.alchim31.maven + scala-maven-plugin + + + -nobootcp + -target:jvm-1.8 + + + + + scala-compile-first + process-resources + + add-source + compile + + + + scala-test-compile + process-test-resources + + testCompile + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + compile + + compile + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + + test-jar + + test-compile + + + + false + + + + org.apache.maven.plugins + maven-surefire-plugin + + ${skip.hudi-spark3.unit.tests} + + + + org.apache.rat + apache-rat-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + org.jacoco + jacoco-maven-plugin + + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark3.5.x/src/main/antlr4 + ../hudi-spark3.5.x/src/main/antlr4/imports + + + + + + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${spark35.version} + provided + true + + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${spark35.version} + provided + true + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark35.version} + provided + true + + + * + * + + + + + + com.fasterxml.jackson.core + jackson-databind + ${fasterxml.spark3.version} + + + com.fasterxml.jackson.core + jackson-annotations + ${fasterxml.spark3.version} + + + com.fasterxml.jackson.core + jackson-core + ${fasterxml.spark3.version} + + + + org.apache.hudi + hudi-spark-client + ${project.version} + + + + org.apache.hudi + hudi-spark-common_${scala.binary.version} + ${project.version} + + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.7.0-M11 + + + com.fasterxml.jackson.core + * + + + + + + + org.apache.hudi + hudi-spark3-common + ${project.version} + + + + + org.apache.hudi + hudi-spark3.2plus-common + ${project.version} + + + + + org.apache.hudi + hudi-tests-common + ${project.version} + test + + + + org.apache.hudi + hudi-client-common + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-spark-client + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-common + ${project.version} + tests + test-jar + test + + + + org.apache.hudi + hudi-spark-common_${scala.binary.version} + ${project.version} + tests + 
test-jar + test + + + + org.apache.spark + spark-core_${scala.binary.version} + ${spark3.version} + tests + test + + + + org.apache.parquet + parquet-avro + test + + + + org.apache.hadoop + hadoop-hdfs + tests + test + + + + org.mortbay.jetty + * + + + javax.servlet.jsp + * + + + javax.servlet + * + + + + + + + diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/imports/SqlBase.g4 new file mode 100644 index 0000000000000..d7f87b4e5aa59 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/imports/SqlBase.g4 @@ -0,0 +1,1940 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +// The parser file is forked from spark 3.2.0's SqlBase.g4. +grammar SqlBase; + +@parser::members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enabled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; +} + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. 
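+   * For example, after consuming '/*' in "SELECT /*+ BROADCAST(t)", the lookahead character +   * is '+', so this returns true and the token is parsed as a hint; for a plain "/* comment" +   * the lookahead is not '+' and this returns false.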
+ */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement ';'* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE NAMESPACE? multipartIdentifier #use + | CREATE namespace (IF NOT EXISTS)? multipartIdentifier + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace + | ALTER namespace multipartIdentifier + SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties + | ALTER namespace multipartIdentifier + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? multipartIdentifier + (RESTRICT | CASCADE)? #dropNamespace + | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showNamespaces + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS + (identifier)? #analyzeTables + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + | ALTER TABLE table=multipartIdentifier + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) + '(' columns=multipartIdentifierList ')' #dropTableColumns + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=multipartIdentifier + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) multipartIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER (TABLE | VIEW) multipartIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE table=multipartIdentifier + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? #alterTableAlterColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + REPLACE COLUMNS + '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)? 
+ partitionSpecLocation+ #addTablePartition + | ALTER TABLE multipartIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) multipartIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE multipartIdentifier + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTempViewUsing + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + multipartIdentifier AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)? + LIKE pattern=STRING partitionSpec? #showTableExtended + | SHOW TBLPROPERTIES table=multipartIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=multipartIdentifier + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showViews + | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS + (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions + | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable + | SHOW CURRENT NAMESPACE #showCurrentNamespace + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + multipartIdentifier #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + multipartIdentifier partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace multipartIdentifier IS + comment=(STRING | NULL) #commentNamespace + | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable + | REFRESH TABLE multipartIdentifier #refreshTable + | REFRESH FUNCTION multipartIdentifier #refreshFunction + | REFRESH (STRING | .*?) #refreshResource + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE + multipartIdentifier partitionSpec? #loadData + | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable + | MSCK REPAIR TABLE multipartIdentifier + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable + | op=(ADD | LIST) identifier .*? #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone + | SET TIME ZONE .*? 
#setTimeZone + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setQuotedConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +configKey + : quotedIdentifier + ; + +configValue + : quotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE multipartIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +commentSpec + : COMMENT STRING + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable + | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? 
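+    // e.g. PARTITION (dt='2023-10-01', country): dt='2023-10-01' binds a constant, while the +    // bare identifier country uses the optional-value form (a dynamic partition column).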
+ ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +describeFuncName + : qualifiedName + | STRING + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier ('.' nameParts+=identifier)* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')' + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=tablePropertyValue)? + ; + +tablePropertyKey + : identifier ('.' identifier)* + | STRING + ; + +tablePropertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +dmlStatementNoWith + : insertInto queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable + | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable + | MERGE INTO target=multipartIdentifier targetAlias=tableAlias + USING (source=multipartIdentifier | + '(' sourceQuery=query')') sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* #mergeIntoTable + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enabled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE multipartIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' query ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? + ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? 
#transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM '(' setQuantifier? expressionSeq ')' + | kind=MAP setQuantifier? expressionSeq + | kind=REDUCE setQuantifier? expressionSeq) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT '(' columns=multipartIdentifierList ')' + VALUES '(' expression (',' expression)* ')' + ; + +assignmentList + : assignment (',' assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' + ; + +fromClause + : FROM relation (',' relation)* lateralView* pivotClause? + ; + +temporalClause + : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) + ; + +aggregationClause + : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause + (',' groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupByClause + : groupingAnalytics + | expression + ; + +groupingAnalytics + : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')' + | GROUPING SETS '(' groupingElement (',' groupingElement)* ')' + ; + +groupingElement + : groupingAnalytics + | groupingSet + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : LATERAL? relationPrimary joinRelation* + ; + +joinRelation + : (joinType) JOIN LATERAL? right=relationPrimary joinCriteria? + | NATURAL joinType JOIN LATERAL? right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? 
percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : multipartIdentifier temporalClause? + sample? tableAlias #tableName + | '(' query ')' sample? tableAlias #aliasedQuery + | '(' relation ')' sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : funcName=functionName '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (',' multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)* + ; + +tableIdentifier + : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier + ; + +multipartIdentifierPropertyList + : multipartIdentifierProperty (COMMA multipartIdentifierProperty)* + ; + +multipartIdentifierProperty + : multipartIdentifier (OPTIONS options=propertyList)? + ; + +propertyList + : LEFT_PAREN property (COMMA property)* RIGHT_PAREN + ; + +property + : key=propertyKey (EQ? value=propertyValue)? + ; + +propertyKey + : identifier (DOT identifier)* + | STRING + ; + +propertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +partitionFieldList + : '(' fields+=partitionField (',' fields+=partitionField)* ')' + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +expressionSeq + : expression (',' expression)* + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? 
kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=RLIKE pattern=valueExpression + | NOT? kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')') + | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER) #currentLike + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | name=(CAST | TRY_CAST) '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last + | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor + | '(' query ')' #subqueryExpression + | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' + (FILTER '(' WHERE where=booleanExpression ')')? + (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract + | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression + ((FOR | ',') len=valueExpression)? ')' #substring + | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression ')' #trim + | OVERLAY '(' input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)? 
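+    // e.g. INTERVAL 2 DAYS (multiUnitsInterval) or INTERVAL '1' YEAR TO MONTH (unitToUnitInterval)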
+ ; + +errorCapturingMultiUnitsInterval + : body=multiUnitsInterval unitToUnitInterval? + ; + +multiUnitsInterval + : (intervalValue unit+=identifier)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=identifier TO to=identifier + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE | STRING) + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType + | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType + | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) + (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition? + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':'? dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | '('name=errorCapturingIdentifier')' #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (',' qualifiedName)* + ; + +functionName + : qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? 
nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. +ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ANALYZE + | ANTI + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CHANGE + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DROP + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GLOBAL + | GROUPING + | HOUR + | IF + | IGNORE + | IMPORT + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | ITEMS + | KEYS + | LAST + | LAZY + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NULLS + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHOW + | SKEWED + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | 
TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNSET + | UPDATE + | USE + | VALUES + | VIEW + | VIEWS + | WINDOW + | YEAR + | ZONE +//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. +// These 2 together contain all the keywords. +strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LATERAL + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ANALYZE + | AND + | ANY + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BOTH + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CASE + | CAST + | CHANGE + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | HOUR + | IF + | IGNORE + | IMPORT + | IN + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LAZY + | LEADING + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NOT + | NULL + | NULLS + | OF + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHOW + | SKEWED + | SOME + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLE + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TIME + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | 
UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VIEW + | VIEWS + | WHEN + | WHERE + | WINDOW + | WITH + | YEAR + | ZONE + | SYSTEM_VERSION + | VERSION + | SYSTEM_TIME + | TIMESTAMP +//--DEFAULT-NON-RESERVED-END + ; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. + +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CHANGE: 'CHANGE'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DATA: 'DATA'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES' | 'SCHEMAS'; +DBPROPERTIES: 'DBPROPERTIES'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; +MSCK: 'MSCK'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +OF: 'OF'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; 
+OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SCHEMA: 'SCHEMA'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SKEWED: 'SKEWED'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; + +SYSTEM_VERSION: 'SYSTEM_VERSION'; +VERSION: 'VERSION'; +SYSTEM_TIME: 'SYSTEM_TIME'; +TIMESTAMP: 'TIMESTAMP'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ +LEFT_PAREN: '('; +RIGHT_PAREN: ')'; +COMMA: ','; +DOT: '.'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' 
DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 new file mode 100644 index 0000000000000..ddbecfefc760d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar HoodieSqlBase; + +import SqlBase; + +singleStatement + : statement EOF + ; + +statement + : query #queryStatement + | ctes? dmlStatementNoWith #dmlStatement + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE INDEX (IF NOT EXISTS)? identifier ON TABLE? + tableIdentifier (USING indexType=identifier)? + LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN + (OPTIONS indexOptions=propertyList)? #createIndex + | DROP INDEX (IF EXISTS)? identifier ON TABLE? tableIdentifier #dropIndex + | SHOW INDEXES (FROM | IN) TABLE? tableIdentifier #showIndexes + | REFRESH INDEX identifier ON TABLE? tableIdentifier #refreshIndex + | .*? #passThrough + ; diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/hudi-spark-datasource/hudi-spark3.5.x/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister new file mode 100644 index 0000000000000..c8dd99a95c27a --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -0,0 +1,19 @@ + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +org.apache.hudi.Spark32PlusDefaultSource \ No newline at end of file diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/hudi/Spark35HoodieFileScanRDD.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/hudi/Spark35HoodieFileScanRDD.scala new file mode 100644 index 0000000000000..9ab3c04605d5f --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/hudi/Spark35HoodieFileScanRDD.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.types.StructType + +class Spark35HoodieFileScanRDD(@transient private val sparkSession: SparkSession, + read: PartitionedFile => Iterator[InternalRow], + @transient filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference] = Seq.empty) + extends FileScanRDD(sparkSession, read, filePartitions, readDataSchema, metadataColumns) + with HoodieUnsafeRDD { + + override final def collect(): Array[InternalRow] = super[HoodieUnsafeRDD].collect() +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalogUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalogUtils.scala new file mode 100644 index 0000000000000..b97f94e7de074 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalogUtils.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.connector.expressions.{BucketTransform, NamedReference, Transform} + +object HoodieSpark35CatalogUtils extends HoodieSpark3CatalogUtils { + + override def unapplyBucketTransform(t: Transform): Option[(Int, Seq[NamedReference], Seq[NamedReference])] = + t match { + case BucketTransform(numBuckets, refs, sortedRefs) => Some((numBuckets, refs, sortedRefs)) + case _ => None + } + +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystExpressionUtils.scala new file mode 100644 index 0000000000000..ae4803dc8b91c --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystExpressionUtils.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.HoodieSparkTypeUtils.isCastPreservingOrdering +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BitwiseOr, Cast, DateAdd, DateDiff, DateFormatClass, DateSub, Divide, EvalMode, Exp, Expm1, Expression, FromUTCTimestamp, FromUnixTime, Log, Log10, Log1p, Log2, Lower, Multiply, ParseToDate, ParseToTimestamp, PredicateHelper, ShiftLeft, ShiftRight, ToUTCTimestamp, ToUnixTimestamp, Upper} +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.types.{DataType, StructType} + +object HoodieSpark35CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils with PredicateHelper { + + override def getEncoder(schema: StructType): ExpressionEncoder[Row] = { + ExpressionEncoder.apply(schema).resolveAndBind() + } + + override def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] = { + DataSourceStrategy.normalizeExprs(exprs, attributes) + } + + override def extractPredicatesWithinOutputSet(condition: Expression, outputSet: AttributeSet): Option[Expression] = { + super[PredicateHelper].extractPredicatesWithinOutputSet(condition, outputSet) + } + + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = { + expr match { + case Cast(child, dataType, timeZoneId, _) => Some((child, dataType, timeZoneId)) + case _ => None + } + } + + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { + expr match { + case OrderPreservingTransformation(attrRef) => Some(attrRef) + case _ => None + } + } + + def canUpCast(fromType: DataType, toType: DataType): Boolean = + Cast.canUpCast(fromType, toType) + + override def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] = + expr match { + case Cast(castedExpr, dataType, timeZoneId, ansiEnabled) => + Some((castedExpr, dataType, timeZoneId, ansiEnabled == EvalMode.ANSI)) + case _ => None + } + + private object OrderPreservingTransformation { + def unapply(expr: Expression): Option[AttributeReference] = { + expr match { + // Date/Time Expressions + case DateFormatClass(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case DateAdd(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateSub(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case DateDiff(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + case FromUnixTime(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case FromUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ParseToDate(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ParseToTimestamp(OrderPreservingTransformation(attrRef), _, _, _, _) => Some(attrRef) + case ToUnixTimestamp(OrderPreservingTransformation(attrRef), _, _, _) => Some(attrRef) + case ToUTCTimestamp(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // String Expressions + case Lower(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Upper(OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Left API change: Improve RuntimeReplaceable + // https://issues.apache.org/jira/browse/SPARK-38240 + case 
org.apache.spark.sql.catalyst.expressions.Left(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Math Expressions + // Binary + case Add(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Add(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Multiply(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case Multiply(_, OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case Divide(OrderPreservingTransformation(attrRef), _, _) => Some(attrRef) + case BitwiseOr(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case BitwiseOr(_, OrderPreservingTransformation(attrRef)) => Some(attrRef) + // Unary + case Exp(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Expm1(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log10(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log1p(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case Log2(OrderPreservingTransformation(attrRef)) => Some(attrRef) + case ShiftLeft(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + case ShiftRight(OrderPreservingTransformation(attrRef), _) => Some(attrRef) + + // Other + case cast @ Cast(OrderPreservingTransformation(attrRef), _, _, _) + if isCastPreservingOrdering(cast.child.dataType, cast.dataType) => Some(attrRef) + + // Identity transformation + case attrRef: AttributeReference => Some(attrRef) + // No match + case _ => None + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystPlanUtils.scala new file mode 100644 index 0000000000000..3ef6abdf8e158 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35CatalystPlanUtils.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.hudi.SparkHoodieTableFileIndex +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, ResolvedTable} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, ProjectionOverSchema} +import org.apache.spark.sql.catalyst.planning.ScanOperation +import org.apache.spark.sql.catalyst.plans.logical.{CreateIndex, DropIndex, LogicalPlan, MergeIntoTable, Project, RefreshIndex, ShowIndexes} +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} +import org.apache.spark.sql.execution.command.RepairTableCommand +import org.apache.spark.sql.execution.datasources.parquet.NewHoodieParquetFileFormat +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.types.StructType + +object HoodieSpark35CatalystPlanUtils extends HoodieSpark3CatalystPlanUtils { + + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = + plan match { + case ResolvedTable(catalog, identifier, table, _) => Some((catalog, identifier, table)) + case _ => None + } + + override def unapplyMergeIntoTable(plan: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Expression)] = { + plan match { + case MergeIntoTable(targetTable, sourceTable, mergeCondition, _, _, _) => + Some((targetTable, sourceTable, mergeCondition)) + case _ => None + } + } + + override def applyNewHoodieParquetFileFormatProjection(plan: LogicalPlan): LogicalPlan = { + plan match { + case s@ScanOperation(_, _, _, + l@LogicalRelation(fs: HadoopFsRelation, _, _, _)) if fs.fileFormat.isInstanceOf[NewHoodieParquetFileFormat] && !fs.fileFormat.asInstanceOf[NewHoodieParquetFileFormat].isProjected => + fs.fileFormat.asInstanceOf[NewHoodieParquetFileFormat].isProjected = true + Project(l.resolve(fs.location.asInstanceOf[SparkHoodieTableFileIndex].schema, fs.sparkSession.sessionState.analyzer.resolver), s) + case _ => plan + } + } + + override def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema = + ProjectionOverSchema(schema, output) + + override def isRepairTable(plan: LogicalPlan): Boolean = { + plan.isInstanceOf[RepairTableCommand] + } + + override def getRepairTableChildren(plan: LogicalPlan): Option[(TableIdentifier, Boolean, Boolean, String)] = { + plan match { + case rtc: RepairTableCommand => + Some((rtc.tableName, rtc.enableAddPartitions, rtc.enableDropPartitions, rtc.cmd)) + case _ => + None + } + } + + override def failAnalysisForMIT(a: Attribute, cols: String): Unit = { + a.failAnalysis( + errorClass = "_LEGACY_ERROR_TEMP_2309", + messageParameters = Map( + "sqlExpr" -> a.sql, + "cols" -> cols)) + } + + override def unapplyCreateIndex(plan: LogicalPlan): Option[(LogicalPlan, String, String, Boolean, Seq[(Seq[String], Map[String, String])], Map[String, String])] = { + plan match { + case ci@CreateIndex(table, indexName, indexType, ignoreIfExists, columns, properties) => + Some((table, indexName, indexType, ignoreIfExists, columns.map(col => (col._1.name, col._2)), properties)) + case _ => + None + } + } + + override def unapplyDropIndex(plan: LogicalPlan): Option[(LogicalPlan, String, Boolean)] = { + plan match { + case ci@DropIndex(table, indexName, ignoreIfNotExists) => + Some((table, indexName, ignoreIfNotExists)) + case _ => + None + } + } + + override def unapplyShowIndexes(plan: LogicalPlan): Option[(LogicalPlan, Seq[Attribute])] = { + plan match { + case 
ci@ShowIndexes(table, output) => + Some((table, output)) + case _ => + None + } + } + + override def unapplyRefreshIndex(plan: LogicalPlan): Option[(LogicalPlan, String)] = { + plan match { + case ci@RefreshIndex(table, indexName) => + Some((table, indexName)) + case _ => + None + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35SchemaUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35SchemaUtils.scala new file mode 100644 index 0000000000000..8c657d91fb031 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/HoodieSpark35SchemaUtils.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +/** + * Utils on schema for Spark 3.5. + */ +object HoodieSpark35SchemaUtils extends HoodieSchemaUtils { + override def checkColumnNameDuplication(columnNames: Seq[String], + colType: String, + caseSensitiveAnalysis: Boolean): Unit = { + SchemaUtils.checkColumnNameDuplication(columnNames, caseSensitiveAnalysis) + } + + override def toAttributes(struct: StructType): Seq[Attribute] = { + DataTypeUtils.toAttributes(struct) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_5Adapter.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_5Adapter.scala new file mode 100644 index 0000000000000..12beba9ba3221 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_5Adapter.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.adapter + +import org.apache.avro.Schema +import org.apache.hadoop.fs.FileStatus +import org.apache.hudi.Spark35HoodieFileScanRDD +import org.apache.spark.sql.avro._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY +import org.apache.spark.sql.connector.catalog.V2TableWithV1Fallback +import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark35LegacyHoodieParquetFileFormat} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.hudi.analysis.TableValuedFunctions +import org.apache.spark.sql.parser.{HoodieExtendedParserInterface, HoodieSpark3_5ExtendedSqlParser} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructType} +import org.apache.spark.sql.vectorized.ColumnarBatchRow +import org.apache.spark.sql._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.storage.StorageLevel._ + +/** + * Implementation of [[SparkAdapter]] for Spark 3.5.x branch + */ +class Spark3_5Adapter extends BaseSpark3Adapter { + + override def resolveHoodieTable(plan: LogicalPlan): Option[CatalogTable] = { + super.resolveHoodieTable(plan).orElse { + EliminateSubqueryAliases(plan) match { + // First, we need to weed out unresolved plans + case plan if !plan.resolved => None + // NOTE: When resolving a Hudi table we allow [[Filter]]s and [[Project]]s to be applied + // on top of it + case PhysicalOperation(_, _, DataSourceV2Relation(v2: V2TableWithV1Fallback, _, _, _, _)) if isHoodieTable(v2.v1Table) => + Some(v2.v1Table) + case _ => None + } + } + } + + override def isColumnarBatchRow(r: InternalRow): Boolean = r.isInstanceOf[ColumnarBatchRow] + + def createCatalystMetadataForMetaField: Metadata = + new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, value = true) + .build() + + override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark35CatalogUtils + + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark35CatalystExpressionUtils + + override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark35CatalystPlanUtils + + override def getSchemaUtils: HoodieSchemaUtils = HoodieSpark35SchemaUtils + + override def getSparkPartitionedFileUtils: HoodieSparkPartitionedFileUtils = HoodieSpark35PartitionedFileUtils + + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = + new HoodieSpark3_5AvroSerializer(rootCatalystType, rootAvroType, nullable) + + override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = + new HoodieSpark3_5AvroDeserializer(rootAvroType, rootCatalystType) + + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark3_5ExtendedSqlParser(spark, delegate) + + 
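// Illustrative usage sketch, not part of this PR (`adapter`, `catalystType`, `avroSchema` and + // `avroRecord` below are hypothetical placeholders): callers reach these hooks through the + // resolved adapter rather than instantiating this class directly, e.g. + //   val serializer = adapter.createAvroSerializer(catalystType, avroSchema, nullable = false) + //   val rowOption  = adapter.createAvroDeserializer(avroSchema, catalystType).deserialize(avroRecord) + + 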
override def createLegacyHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { + Some(new Spark35LegacyHoodieParquetFileFormat(appendPartitionValues)) + } + + override def createHoodieFileScanRDD(sparkSession: SparkSession, + readFunction: PartitionedFile => Iterator[InternalRow], + filePartitions: Seq[FilePartition], + readDataSchema: StructType, + metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD = { + new Spark35HoodieFileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns) + } + + override def extractDeleteCondition(deleteFromTable: Command): Expression = { + deleteFromTable.asInstanceOf[DeleteFromTable].condition + } + + override def injectTableFunctions(extensions: SparkSessionExtensions): Unit = { + TableValuedFunctions.funcs.foreach(extensions.injectTableFunction) + } + + /** + * Converts instance of [[StorageLevel]] to a corresponding string + */ + override def convertStorageLevelToString(level: StorageLevel): String = level match { + case NONE => "NONE" + case DISK_ONLY => "DISK_ONLY" + case DISK_ONLY_2 => "DISK_ONLY_2" + case DISK_ONLY_3 => "DISK_ONLY_3" + case MEMORY_ONLY => "MEMORY_ONLY" + case MEMORY_ONLY_2 => "MEMORY_ONLY_2" + case MEMORY_ONLY_SER => "MEMORY_ONLY_SER" + case MEMORY_ONLY_SER_2 => "MEMORY_ONLY_SER_2" + case MEMORY_AND_DISK => "MEMORY_AND_DISK" + case MEMORY_AND_DISK_2 => "MEMORY_AND_DISK_2" + case MEMORY_AND_DISK_SER => "MEMORY_AND_DISK_SER" + case MEMORY_AND_DISK_SER_2 => "MEMORY_AND_DISK_SER_2" + case OFF_HEAP => "OFF_HEAP" + case _ => throw new IllegalArgumentException(s"Invalid StorageLevel: $level") + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala new file mode 100644 index 0000000000000..583e2da0e65a9 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import java.math.BigDecimal +import java.nio.ByteBuffer +import scala.collection.JavaConverters._ +import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder} +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis} +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic._ +import org.apache.avro.util.Utf8 +import org.apache.spark.sql.avro.AvroDeserializer.{RebaseSpec, createDateRebaseFuncInRead, createTimestampRebaseFuncInRead} +import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr} +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData, RebaseDateTime} +import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_DAY +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.LegacyBehaviorPolicy +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +import java.util.TimeZone + +/** + * A deserializer to deserialize data in avro format to data in catalyst format. + * + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * PLEASE REFRAIN FROM MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] class AvroDeserializer(rootAvroType: Schema, + rootCatalystType: DataType, + positionalFieldMatch: Boolean, + datetimeRebaseSpec: RebaseSpec, + filters: StructFilters) { + + def this(rootAvroType: Schema, + rootCatalystType: DataType, + datetimeRebaseMode: String) = { + this( + rootAvroType, + rootCatalystType, + positionalFieldMatch = false, + RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), + new NoopFilters) + } + + private lazy val decimalConversions = new DecimalConversion() + + private val dateRebaseFunc = createDateRebaseFuncInRead(datetimeRebaseSpec.mode, "Avro") + + private val timestampRebaseFunc = createTimestampRebaseFuncInRead(datetimeRebaseSpec, "Avro") + + private val converter: Any => Option[Any] = try { + rootCatalystType match { + // A shortcut for empty schema. 
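+      // (e.g. a zero-column projection: every incoming record then maps to + // InternalRow.empty without inspecting any Avro field)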
+ case st: StructType if st.isEmpty => + (_: Any) => Some(InternalRow.empty) + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val applyFilters = filters.skipRow(resultRow, _) + val writer = getRecordWriter(rootAvroType, st, Nil, Nil, applyFilters) + (data: Any) => { + val record = data.asInstanceOf[GenericRecord] + val skipRow = writer(fieldUpdater, record) + if (skipRow) None else Some(resultRow) + } + + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val fieldUpdater = new RowUpdater(tmpRow) + val writer = newWriter(rootAvroType, rootCatalystType, Nil, Nil) + (data: Any) => { + writer(fieldUpdater, 0, data) + Some(tmpRow.get(0, rootCatalystType)) + } + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert Avro type $rootAvroType to SQL type ${rootCatalystType.sql}.", ise) + } + + def deserialize(data: Any): Option[Any] = converter(data) + + /** + * Creates a writer to write avro values to Catalyst values at the given ordinal with the given + * updater. + */ + private def newWriter(avroType: Schema, + catalystType: DataType, + avroPath: Seq[String], + catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = { + val errorPrefix = s"Cannot convert Avro ${toFieldStr(avroPath)} to " + + s"SQL ${toFieldStr(catalystPath)} because " + val incompatibleMsg = errorPrefix + + s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" + + (avroType.getType, catalystType) match { + case (NULL, NullType) => (updater, ordinal, _) => + updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of avro provide primitive accessors. + case (BOOLEAN, BooleanType) => (updater, ordinal, value) => + updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, dateRebaseFunc(value.asInstanceOf[Int])) + + case (LONG, LongType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (LONG, TimestampType) => avroType.getLogicalType match { + // For backward compatibility, if the Avro type is Long and it is not logical type + // (the `null` case), the value is processed as timestamp type with millisecond precision. + case null | _: TimestampMillis => (updater, ordinal, value) => + val millis = value.asInstanceOf[Long] + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case _: TimestampMicros => (updater, ordinal, value) => + val micros = value.asInstanceOf[Long] + updater.setLong(ordinal, timestampRebaseFunc(micros)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampType.sql}.") + } + + case (LONG, TimestampNTZType) => avroType.getLogicalType match { + // To keep consistent with TimestampType, if the Avro type is Long and it is not + // logical type (the `null` case), the value is processed as TimestampNTZ + // with millisecond precision. 
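+        // (e.g. a plain Avro long value of 1000 is read as 1000 ms and stored as the + // microsecond value 1000000 via DateTimeUtils.millisToMicros)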
+ case null | _: LocalTimestampMillis => (updater, ordinal, value) => + val millis = value.asInstanceOf[Long] + val micros = DateTimeUtils.millisToMicros(millis) + updater.setLong(ordinal, micros) + case _: LocalTimestampMicros => (updater, ordinal, value) => + val micros = value.asInstanceOf[Long] + updater.setLong(ordinal, micros) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"Avro logical type $other cannot be converted to SQL type ${TimestampNTZType.sql}.") + } + + // Before we upgrade Avro to 1.8 for logical type support, spark-avro converts Long to Date. + // For backward compatibility, we still keep this conversion. + case (LONG, DateType) => (updater, ordinal, value) => + updater.setInt(ordinal, (value.asInstanceOf[Long] / MILLIS_PER_DAY).toInt) + + case (FLOAT, FloatType) => (updater, ordinal, value) => + updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => (updater, ordinal, value) => + updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + case s: Utf8 => + val bytes = new Array[Byte](s.getByteLength) + System.arraycopy(s.getBytes, 0, bytes, 0, s.getByteLength) + UTF8String.fromBytes(bytes) + case s: GenericData.EnumSymbol => UTF8String.fromString(s.toString) + } + updater.set(ordinal, str) + + case (ENUM, StringType) => (updater, ordinal, value) => + updater.set(ordinal, UTF8String.fromString(value.toString)) + + case (FIXED, BinaryType) => (updater, ordinal, value) => + updater.set(ordinal, value.asInstanceOf[GenericFixed].bytes().clone()) + + case (BYTES, BinaryType) => (updater, ordinal, value) => + val bytes = value match { + case b: ByteBuffer => + val bytes = new Array[Byte](b.remaining) + b.get(bytes) + // Do not forget to reset the position + b.rewind() + bytes + case b: Array[Byte] => b + case other => + throw new RuntimeException(errorPrefix + s"$other is not a valid avro binary.") + } + updater.set(ordinal, bytes) + + case (FIXED, _: DecimalType) => (updater, ordinal, value) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + val bigDecimal = decimalConversions.fromFixed(value.asInstanceOf[GenericFixed], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) + + case (BYTES, _: DecimalType) => (updater, ordinal, value) => + val d = avroType.getLogicalType.asInstanceOf[LogicalTypes.Decimal] + val bigDecimal = decimalConversions.fromBytes(value.asInstanceOf[ByteBuffer], avroType, d) + val decimal = createDecimal(bigDecimal, d.getPrecision, d.getScale) + updater.setDecimal(ordinal, decimal) + + case (RECORD, st: StructType) => + // Avro datasource doesn't accept filters with nested attributes. See SPARK-32328. + // We can always return `false` from `applyFilters` for nested records. 
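+      // (pushed-down filters can therefore only ever prune rows on top-level fields)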
+ val writeRecord = + getRecordWriter(avroType, st, avroPath, catalystPath, applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[GenericRecord]) + updater.set(ordinal, row) + + case (ARRAY, ArrayType(elementType, containsNull)) => + val avroElementPath = avroPath :+ "element" + val elementWriter = newWriter(avroType.getElementType, elementType, + avroElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iter = collection.iterator() + while (iter.hasNext) { + val element = iter.next() + if (element == null) { + if (!containsNull) { + throw new RuntimeException( + s"Array value at path ${toFieldStr(avroElementPath)} is not allowed to be null") + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + + case (MAP, MapType(keyType, valueType, valueContainsNull)) if keyType == StringType => + val keyWriter = newWriter(SchemaBuilder.builder().stringType(), StringType, + avroPath :+ "key", catalystPath :+ "key") + val valueWriter = newWriter(avroType.getValueType, valueType, + avroPath :+ "value", catalystPath :+ "value") + (updater, ordinal, value) => + val map = value.asInstanceOf[java.util.Map[AnyRef, AnyRef]] + val keyArray = createArrayData(keyType, map.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + val valueArray = createArrayData(valueType, map.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val iter = map.entrySet().iterator() + var i = 0 + while (iter.hasNext) { + val entry = iter.next() + assert(entry.getKey != null) + keyWriter(keyUpdater, i, entry.getKey) + if (entry.getValue == null) { + if (!valueContainsNull) { + throw new RuntimeException( + s"Map value at path ${toFieldStr(avroPath :+ "value")} is not allowed to be null") + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, entry.getValue) + } + i += 1 + } + + // The Avro map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. 
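+          // (so no key-deduplicating map builder is required on this path)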
+ updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + + case (UNION, _) => + val allTypes = avroType.getTypes.asScala + val nonNullTypes = allTypes.filter(_.getType != NULL) + val nonNullAvroType = Schema.createUnion(nonNullTypes.asJava) + if (nonNullTypes.nonEmpty) { + if (nonNullTypes.length == 1) { + newWriter(nonNullTypes.head, catalystType, avroPath, catalystPath) + } else { + nonNullTypes.map(_.getType).toSeq match { + case Seq(a, b) if Set(a, b) == Set(INT, LONG) && catalystType == LongType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case l: java.lang.Long => updater.setLong(ordinal, l) + case i: java.lang.Integer => updater.setLong(ordinal, i.longValue()) + } + + case Seq(a, b) if Set(a, b) == Set(FLOAT, DOUBLE) && catalystType == DoubleType => + (updater, ordinal, value) => value match { + case null => updater.setNullAt(ordinal) + case d: java.lang.Double => updater.setDouble(ordinal, d) + case f: java.lang.Float => updater.setDouble(ordinal, f.doubleValue()) + } + + case _ => + catalystType match { + case st: StructType if st.length == nonNullTypes.size => + val fieldWriters = nonNullTypes.zip(st.fields).map { + case (schema, field) => + newWriter(schema, field.dataType, avroPath, catalystPath :+ field.name) + }.toArray + (updater, ordinal, value) => { + val row = new SpecificInternalRow(st) + val fieldUpdater = new RowUpdater(row) + val i = GenericData.get().resolveUnion(nonNullAvroType, value) + fieldWriters(i)(fieldUpdater, i, value) + updater.set(ordinal, row) + } + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + } + } else { + (updater, ordinal, _) => updater.setNullAt(ordinal) + } + + case (INT, _: YearMonthIntervalType) => (updater, ordinal, value) => + updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (LONG, _: DayTimeIntervalType) => (updater, ordinal, value) => + updater.setLong(ordinal, value.asInstanceOf[Long]) + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + + // TODO: move the following method in Decimal object on creating Decimal from BigDecimal? + private def createDecimal(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { + if (precision <= Decimal.MAX_LONG_DIGITS) { + // Constructs a `Decimal` with an unscaled `Long` value if possible. + Decimal(decimal.unscaledValue().longValue(), precision, scale) + } else { + // Otherwise, resorts to an unscaled `BigInteger` instead. 
+ Decimal(decimal, precision, scale) + } + } + + private def getRecordWriter( + avroType: Schema, + catalystType: StructType, + avroPath: Seq[String], + catalystPath: Seq[String], + applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = { + + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroType, catalystType, avroPath, catalystPath, positionalFieldMatch) + + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + // no need to validateNoExtraAvroFields since extra Avro fields are ignored + + val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map { + case AvroMatchedField(catalystField, ordinal, avroField) => + val baseWriter = newWriter(avroField.schema(), catalystField.dataType, + avroPath :+ avroField.name, catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + (avroField.pos(), fieldWriter) + }.toArray.unzip + + (fieldUpdater, record) => { + var i = 0 + var skipRow = false + while (i < validFieldIndexes.length && !skipRow) { + fieldWriters(i)(fieldUpdater, record.get(validFieldIndexes(i))) + skipRow = applyFilters(i) + i += 1 + } + skipRow + } + } + + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } +} + +object AvroDeserializer { + + // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroDeserializer]] implementation + // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]]. 
+ // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch, + // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as + // w/ Spark >= 3.2.1 + // + // [1] https://github.com/apache/spark/pull/34978 + + // Specification of rebase operation including `mode` and the time zone in which it is performed + case class RebaseSpec(mode: LegacyBehaviorPolicy.Value, originTimeZone: Option[String] = None) { + // Use the default JVM time zone for backward compatibility + def timeZone: String = originTimeZone.getOrElse(TimeZone.getDefault.getID) + } + + def createDateRebaseFuncInRead(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchJulianDay) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseJulianToGregorianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def createTimestampRebaseFuncInRead(rebaseSpec: RebaseSpec, + format: String): Long => Long = rebaseSpec.mode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchJulianTs) { + throw DataSourceUtils.newRebaseExceptionInRead(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => micros: Long => + RebaseDateTime.rebaseJulianToGregorianMicros(TimeZone.getTimeZone(rebaseSpec.timeZone), micros) + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala new file mode 100644 index 0000000000000..a2ed346a97e1a --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes.{LocalTimestampMicros, LocalTimestampMillis, TimestampMicros, TimestampMillis} +import org.apache.avro.{LogicalTypes, Schema} +import org.apache.avro.Schema.Type +import org.apache.avro.Schema.Type._ +import org.apache.avro.generic.GenericData.{EnumSymbol, Fixed, Record} +import org.apache.avro.util.Utf8 +import org.apache.spark.internal.Logging +import org.apache.spark.sql.avro.AvroSerializer.{createDateRebaseFuncInWrite, createTimestampRebaseFuncInWrite} +import org.apache.spark.sql.avro.AvroUtils.{AvroMatchedField, toFieldStr} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, RebaseDateTime} +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.types._ + +import java.nio.ByteBuffer +import java.util.TimeZone +import scala.collection.JavaConverters._ + +/** + * A serializer to serialize data in catalyst format to data in avro format. + * + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * NOTE: THIS IMPLEMENTATION HAS BEEN MODIFIED FROM ITS ORIGINAL VERSION WITH THE MODIFICATION + * BEING EXPLICITLY ANNOTATED INLINE. PLEASE MAKE SURE TO UNDERSTAND PROPERLY ALL THE + * MODIFICATIONS. + * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] class AvroSerializer(rootCatalystType: DataType, + rootAvroType: Schema, + nullable: Boolean, + positionalFieldMatch: Boolean, + datetimeRebaseMode: LegacyBehaviorPolicy.Value) extends Logging { + + def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = { + this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false, + LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE, + LegacyBehaviorPolicy.CORRECTED.toString))) + } + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val dateRebaseFunc = createDateRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + + private val timestampRebaseFunc = createTimestampRebaseFuncInWrite( + datetimeRebaseMode, "Avro") + + private val converter: Any => Any = { + val actualAvroType = resolveNullableType(rootAvroType, nullable) + val baseConverter = try { + rootCatalystType match { + case st: StructType => + newStructConverter(st, actualAvroType, Nil, Nil).asInstanceOf[Any => Any] + case _ => + val tmpRow = new SpecificInternalRow(Seq(rootCatalystType)) + val converter = newConverter(rootCatalystType, actualAvroType, Nil, Nil) + (data: Any) => + tmpRow.update(0, data) + converter.apply(tmpRow, 0) + } + } catch { + case ise: IncompatibleSchemaException => throw new IncompatibleSchemaException( + s"Cannot convert SQL type ${rootCatalystType.sql} to Avro type $rootAvroType.", ise) + } + if (nullable) { + (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private lazy val decimalConversions = new DecimalConversion() + + private def newConverter(catalystType: DataType, + avroType: Schema, + catalystPath: Seq[String], + 
avroPath: Seq[String]): Converter = { + val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " + + s"to Avro ${toFieldStr(avroPath)} because " + (catalystType, avroType.getType) match { + case (NullType, NULL) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (d: DecimalType, FIXED) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toFixed(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (d: DecimalType, BYTES) + if avroType.getLogicalType == LogicalTypes.decimal(d.precision, d.scale) => + (getter, ordinal) => + val decimal = getter.getDecimal(ordinal, d.precision, d.scale) + decimalConversions.toBytes(decimal.toJavaBigDecimal, avroType, + LogicalTypes.decimal(d.precision, d.scale)) + + case (StringType, ENUM) => + val enumSymbols: Set[String] = avroType.getEnumSymbols.asScala.toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException(errorPrefix + + s""""$data" cannot be written since it's not defined in enum """ + + enumSymbols.mkString("\"", "\", \"", "\"")) + } + new EnumSymbol(avroType, data) + + case (StringType, STRING) => + (getter, ordinal) => new Utf8(getter.getUTF8String(ordinal).getBytes) + + case (BinaryType, FIXED) => + val size = avroType.getFixedSize + (getter, ordinal) => + val data: Array[Byte] = getter.getBinary(ordinal) + if (data.length != size) { + def len2str(len: Int): String = s"$len ${if (len > 1) "bytes" else "byte"}" + + throw new IncompatibleSchemaException(errorPrefix + len2str(data.length) + + " of binary data cannot be written into FIXED type with size of " + len2str(size)) + } + new Fixed(avroType, data) + + case (BinaryType, BYTES) => + (getter, ordinal) => ByteBuffer.wrap(getter.getBinary(ordinal)) + + case (DateType, INT) => + (getter, ordinal) => dateRebaseFunc(getter.getInt(ordinal)) + + case (TimestampType, LONG) => avroType.getLogicalType match { + // For backward compatibility, if the Avro type is Long and it is not logical type + // (the `null` case), output the timestamp value as with millisecond precision. + case null | _: TimestampMillis => (getter, ordinal) => + DateTimeUtils.microsToMillis(timestampRebaseFunc(getter.getLong(ordinal))) + case _: TimestampMicros => (getter, ordinal) => + timestampRebaseFunc(getter.getLong(ordinal)) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampType.sql} cannot be converted to Avro logical type $other") + } + + case (TimestampNTZType, LONG) => avroType.getLogicalType match { + // To keep consistent with TimestampType, if the Avro type is Long and it is not + // logical type (the `null` case), output the TimestampNTZ as long value + // in millisecond precision. 
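+      // For example, a TimestampNTZ of 2021-01-01T00:00:00 is held internally as
+      // 1609459200000000 microseconds; a plain-long or `local-timestamp-millis` Avro type is
+      // written as 1609459200000, while `local-timestamp-micros` receives the value unchanged.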
+ case null | _: LocalTimestampMillis => (getter, ordinal) => + DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + case _: LocalTimestampMicros => (getter, ordinal) => + getter.getLong(ordinal) + case other => throw new IncompatibleSchemaException(errorPrefix + + s"SQL type ${TimestampNTZType.sql} cannot be converted to Avro logical type $other") + } + + case (ArrayType(et, containsNull), ARRAY) => + val elementConverter = newConverter( + et, resolveNullableType(avroType.getElementType, containsNull), + catalystPath :+ "element", avroPath :+ "element") + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // avro writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, RECORD) => + val structConverter = newStructConverter(st, avroType, catalystPath, avroPath) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + //////////////////////////////////////////////////////////////////////////////////////////// + // Following section is amended to the original (Spark's) implementation + // >>> BEGINS + //////////////////////////////////////////////////////////////////////////////////////////// + + case (st: StructType, UNION) => + val unionConverter = newUnionConverter(st, avroType, catalystPath, avroPath) + val numFields = st.length + (getter, ordinal) => unionConverter(getter.getStruct(ordinal, numFields)) + + //////////////////////////////////////////////////////////////////////////////////////////// + // <<< ENDS + //////////////////////////////////////////////////////////////////////////////////////////// + + case (MapType(kt, vt, valueContainsNull), MAP) if kt == StringType => + val valueConverter = newConverter( + vt, resolveNullableType(avroType.getValueType, valueContainsNull), + catalystPath :+ "value", avroPath :+ "value") + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val result = new java.util.HashMap[String, Any](len) + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val key = keyArray.getUTF8String(i).toString + if (valueContainsNull && valueArray.isNullAt(i)) { + result.put(key, null) + } else { + result.put(key, valueConverter(valueArray, i)) + } + i += 1 + } + result + + case (_: YearMonthIntervalType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (_: DayTimeIntervalType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + + case _ => + throw new IncompatibleSchemaException(errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, avroType = $avroType)") + } + } + + private def newStructConverter(catalystStruct: StructType, + avroStruct: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): InternalRow => Record = { + + val avroSchemaHelper = new AvroUtils.AvroSchemaHelper( + avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch) + + avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + avroSchemaHelper.validateNoExtraRequiredAvroFields() + + val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map { + case 
AvroMatchedField(catalystField, _, avroField) => + val converter = newConverter(catalystField.dataType, + resolveNullableType(avroField.schema(), catalystField.nullable), + catalystPath :+ catalystField.name, avroPath :+ avroField.name) + (avroField.pos(), converter) + }.toArray.unzip + + val numFields = catalystStruct.length + row: InternalRow => + val result = new Record(avroStruct) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + result.put(avroIndices(i), null) + } else { + result.put(avroIndices(i), fieldConverters(i).apply(row, i)) + } + i += 1 + } + result + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // Following section is amended to the original (Spark's) implementation + // >>> BEGINS + //////////////////////////////////////////////////////////////////////////////////////////// + + private def newUnionConverter(catalystStruct: StructType, + avroUnion: Schema, + catalystPath: Seq[String], + avroPath: Seq[String]): InternalRow => Any = { + if (avroUnion.getType != UNION || !canMapUnion(catalystStruct, avroUnion)) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst type $catalystStruct to " + + s"Avro type $avroUnion.") + } + val nullable = avroUnion.getTypes.size() > 0 && avroUnion.getTypes.get(0).getType == Type.NULL + val avroInnerTypes = if (nullable) { + avroUnion.getTypes.asScala.tail + } else { + avroUnion.getTypes.asScala + } + val fieldConverters = catalystStruct.zip(avroInnerTypes).map { + case (f1, f2) => newConverter(f1.dataType, f2, catalystPath, avroPath) + } + val numFields = catalystStruct.length + (row: InternalRow) => + var i = 0 + var result: Any = null + while (i < numFields) { + if (!row.isNullAt(i)) { + if (result != null) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " + + s"Avro union $avroUnion. Record has more than one optional values set") + } + result = fieldConverters(i).apply(row, i) + } + i += 1 + } + if (!nullable && result == null) { + throw new IncompatibleSchemaException(s"Cannot convert Catalyst record $catalystStruct to " + + s"Avro union $avroUnion. Record has no values set, while should have exactly one") + } + result + } + + private def canMapUnion(catalystStruct: StructType, avroStruct: Schema): Boolean = { + (avroStruct.getTypes.size() > 0 && + avroStruct.getTypes.get(0).getType == Type.NULL && + avroStruct.getTypes.size() - 1 == catalystStruct.length) || avroStruct.getTypes.size() == catalystStruct.length + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // <<< ENDS + //////////////////////////////////////////////////////////////////////////////////////////// + + + /** + * Resolve a possibly nullable Avro Type. + * + * An Avro type is nullable when it is a [[UNION]] of two types: one null type and another + * non-null type. This method will check the nullability of the input Avro type and return the + * non-null type within when it is nullable. Otherwise it will return the input Avro type + * unchanged. It will throw an [[UnsupportedAvroTypeException]] when the input Avro type is an + * unsupported nullable type. + * + * It will also log a warning message if the nullability for Avro and catalyst types are + * different. 
+ */ + private def resolveNullableType(avroType: Schema, nullable: Boolean): Schema = { + val (avroNullable, resolvedAvroType) = resolveAvroType(avroType) + warnNullabilityDifference(avroNullable, nullable) + resolvedAvroType + } + + /** + * Check the nullability of the input Avro type and resolve it when it is nullable. The first + * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second + * return value is the possibly resolved type. + */ + private def resolveAvroType(avroType: Schema): (Boolean, Schema) = { + if (avroType.getType == Type.UNION) { + val fields = avroType.getTypes.asScala + val actualType = fields.filter(_.getType != Type.NULL) + if (fields.length == 2 && actualType.length == 1) { + (true, actualType.head) + } else { + // This is just a normal union, not used to designate nullability + (false, avroType) + } + } else { + (false, avroType) + } + } + + /** + * log a warning message if the nullability for Avro and catalyst types are different. + */ + private def warnNullabilityDifference(avroNullable: Boolean, catalystNullable: Boolean): Unit = { + if (avroNullable && !catalystNullable) { + logWarning("Writing Avro files with nullable Avro schema and non-nullable catalyst schema.") + } + if (!avroNullable && catalystNullable) { + logWarning("Writing Avro files with non-nullable Avro schema and nullable catalyst " + + "schema will throw runtime exception if there is a record with null value.") + } + } +} + +object AvroSerializer { + + // NOTE: Following methods have been renamed in Spark 3.2.1 [1] making [[AvroSerializer]] implementation + // (which relies on it) be only compatible with the exact same version of [[DataSourceUtils]]. + // To make sure this implementation is compatible w/ all Spark versions w/in Spark 3.2.x branch, + // we're preemptively cloned those methods to make sure Hudi is compatible w/ Spark 3.2.0 as well as + // w/ Spark >= 3.2.1 + // + // [1] https://github.com/apache/spark/pull/34978 + + def createDateRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Int => Int = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => days: Int => + if (days < RebaseDateTime.lastSwitchGregorianDay) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + days + case LegacyBehaviorPolicy.LEGACY => RebaseDateTime.rebaseGregorianToJulianDays + case LegacyBehaviorPolicy.CORRECTED => identity[Int] + } + + def createTimestampRebaseFuncInWrite(rebaseMode: LegacyBehaviorPolicy.Value, + format: String): Long => Long = rebaseMode match { + case LegacyBehaviorPolicy.EXCEPTION => micros: Long => + if (micros < RebaseDateTime.lastSwitchGregorianTs) { + throw DataSourceUtils.newRebaseExceptionInWrite(format) + } + micros + case LegacyBehaviorPolicy.LEGACY => + val timeZone = SQLConf.get.sessionLocalTimeZone + RebaseDateTime.rebaseGregorianToJulianMicros(TimeZone.getTimeZone(timeZone), _) + case LegacyBehaviorPolicy.CORRECTED => identity[Long] + } + +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala new file mode 100644 index 0000000000000..b9845c491dc0c --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import java.util.Locale + +import scala.collection.JavaConverters._ + +import org.apache.avro.Schema +import org.apache.avro.file. FileReader +import org.apache.avro.generic.GenericRecord + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +/** + * NOTE: This code is borrowed from Spark 3.3.0 + * This code is borrowed, so that we can better control compatibility w/in Spark minor + * branches (3.2.x, 3.1.x, etc) + * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +private[sql] object AvroUtils extends Logging { + + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _: NullType => true + + case _ => false + } + + // The trait provides iterator-like interface for reading records from an Avro file, + // deserializing and returning them as internal rows. + trait RowReader { + protected val fileReader: FileReader[GenericRecord] + protected val deserializer: AvroDeserializer + protected val stopPosition: Long + + private[this] var completed = false + private[this] var currentRow: Option[InternalRow] = None + + def hasNextRow: Boolean = { + while (!completed && currentRow.isEmpty) { + val r = fileReader.hasNext && !fileReader.pastSync(stopPosition) + if (!r) { + fileReader.close() + completed = true + currentRow = None + } else { + val record = fileReader.next() + // the row must be deserialized in hasNextRow, because AvroDeserializer#deserialize + // potentially filters rows + currentRow = deserializer.deserialize(record).asInstanceOf[Option[InternalRow]] + } + } + currentRow.isDefined + } + + def nextRow: InternalRow = { + if (currentRow.isEmpty) { + hasNextRow + } + val returnRow = currentRow + currentRow = None // free up hasNextRow to consume more Avro records, if not exhausted + returnRow.getOrElse { + throw new NoSuchElementException("next on empty iterator") + } + } + } + + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field. */ + private[sql] case class AvroMatchedField( + catalystField: StructField, + catalystPosition: Int, + avroField: Schema.Field) + + /** + * Helper class to perform field lookup/matching on Avro schemas. 
+ * + * This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in + * the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for + * case sensitivity. The match results can be accessed using the getter methods. + * + * @param avroSchema The schema in which to search for fields. Must be of type RECORD. + * @param catalystSchema The Catalyst schema to use for matching. + * @param avroPath The seq of parent field names leading to `avroSchema`. + * @param catalystPath The seq of parent field names leading to `catalystSchema`. + * @param positionalFieldMatch If true, perform field matching in a positional fashion + * (structural comparison between schemas, ignoring names); + * otherwise, perform field matching using field names. + */ + class AvroSchemaHelper( + avroSchema: Schema, + catalystSchema: StructType, + avroPath: Seq[String], + catalystPath: Seq[String], + positionalFieldMatch: Boolean) { + if (avroSchema.getType != Schema.Type.RECORD) { + throw new IncompatibleSchemaException( + s"Attempting to treat ${avroSchema.getName} as a RECORD, but it was: ${avroSchema.getType}") + } + + private[this] val avroFieldArray = avroSchema.getFields.asScala.toArray + private[this] val fieldMap = avroSchema.getFields.asScala + .groupBy(_.name.toLowerCase(Locale.ROOT)) + .mapValues(_.toSeq) // toSeq needed for scala 2.13 + + /** The fields which have matching equivalents in both Avro and Catalyst schemas. */ + val matchedFields: Seq[AvroMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getAvroField(sqlField.name, sqlPos).map(AvroMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Avro field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. If `ignoreNullable` is false, + * consider nullable Catalyst fields to be eligible to be an extra field; otherwise, + * ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.zipWithIndex.foreach { case (sqlField, sqlPos) => + if (getAvroField(sqlField.name, sqlPos).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + if (positionalFieldMatch) { + throw new IncompatibleSchemaException("Cannot find field at position " + + s"$sqlPos of ${toFieldStr(avroPath)} from Avro schema (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Avro schema") + } + } + } + + /** + * Validate that there are no Avro fields which don't have a matching Catalyst field, throwing + * [[IncompatibleSchemaException]] if such extra fields are found. Only required (non-nullable) + * fields are checked; nullable fields are ignored. 
+ */ + def validateNoExtraRequiredAvroFields(): Unit = { + val extraFields = avroFieldArray.toSet -- matchedFields.map(_.avroField) + extraFields.filterNot(isNullable).foreach { extraField => + if (positionalFieldMatch) { + throw new IncompatibleSchemaException(s"Found field '${extraField.name()}' at position " + + s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " + + s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)") + } else { + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " + + "match in the SQL schema") + } + } + } + + /** + * Extract a single field from the contained avro schema which has the desired field name, + * performing the matching with proper case sensitivity according to SQLConf.resolver. + * + * @param name The name of the field to search for. + * @return `Some(match)` if a matching Avro field is found, otherwise `None`. + */ + private[avro] def getFieldByName(name: String): Option[Schema.Field] = { + + // get candidates, ignoring case of field name + val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty) + + // search candidates, taking into account case sensitivity settings + candidates.filter(f => SQLConf.get.resolver(f.name(), name)) match { + case Seq(avroField) => Some(avroField) + case Seq() => None + case matches => throw new IncompatibleSchemaException(s"Searching for '$name' in Avro " + + s"schema at ${toFieldStr(avroPath)} gave ${matches.size} matches. Candidates: " + + matches.map(_.name()).mkString("[", ", ", "]") + ) + } + } + + /** Get the Avro field corresponding to the provided Catalyst field name/position, if any. */ + def getAvroField(fieldName: String, catalystPos: Int): Option[Schema.Field] = { + if (positionalFieldMatch) { + avroFieldArray.lift(catalystPos) + } else { + getFieldByName(fieldName) + } + } + } + + /** + * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable + * string representing the field, like "field 'foo.bar'". If `names` is empty, the string + * "top-level record" is returned. + */ + private[avro] def toFieldStr(names: Seq[String]): String = names match { + case Seq() => "top-level record" + case n => s"field '${n.mkString(".")}'" + } + + /** Return true iff `avroField` is nullable, i.e. `UNION` type and has `NULL` as an option. */ + private[avro] def isNullable(avroField: Schema.Field): Boolean = + avroField.schema().getType == Schema.Type.UNION && + avroField.schema().getTypes.asScala.exists(_.getType == Schema.Type.NULL) +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroDeserializer.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroDeserializer.scala new file mode 100644 index 0000000000000..c99b1a499f69c --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroDeserializer.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.types.DataType + +class HoodieSpark3_5AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) + extends HoodieAvroDeserializer { + + private val avroDeserializer = new AvroDeserializer(rootAvroType, rootCatalystType, + SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ, LegacyBehaviorPolicy.CORRECTED.toString)) + + def deserialize(data: Any): Option[Any] = avroDeserializer.deserialize(data) +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroSerializer.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroSerializer.scala new file mode 100644 index 0000000000000..639f16cb3c966 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/avro/HoodieSpark3_5AvroSerializer.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.avro + +import org.apache.avro.Schema +import org.apache.spark.sql.types.DataType + +class HoodieSpark3_5AvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) + extends HoodieAvroSerializer { + + val avroSerializer = new AvroSerializer(rootCatalystType, rootAvroType, nullable) + + override def serialize(catalystData: Any): Any = avroSerializer.serialize(catalystData) +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark35PartitionedFileUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark35PartitionedFileUtils.scala new file mode 100644 index 0000000000000..611ccf7c0b1ad --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/HoodieSpark35PartitionedFileUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.paths.SparkPath +import org.apache.spark.sql.catalyst.InternalRow + +/** + * Utils on Spark [[PartitionedFile]] and [[PartitionDirectory]] for Spark 3.5. + */ +object HoodieSpark35PartitionedFileUtils extends HoodieSparkPartitionedFileUtils { + override def getPathFromPartitionedFile(partitionedFile: PartitionedFile): Path = { + partitionedFile.filePath.toPath + } + + override def getStringPathFromPartitionedFile(partitionedFile: PartitionedFile): String = { + partitionedFile.filePath.toString + } + + override def createPartitionedFile(partitionValues: InternalRow, + filePath: Path, + start: Long, + length: Long): PartitionedFile = { + PartitionedFile(partitionValues, SparkPath.fromPath(filePath), start, length) + } + + override def toFileStatuses(partitionDirs: Seq[PartitionDirectory]): Seq[FileStatus] = { + partitionDirs.flatMap(_.files).map(_.fileStatus) + } + + override def newPartitionDirectory(internalRow: InternalRow, statuses: Seq[FileStatus]): PartitionDirectory = { + PartitionDirectory(internalRow, statuses.toArray) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark35NestedSchemaPruning.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark35NestedSchemaPruning.scala new file mode 100644 index 0000000000000..966ade0db79c0 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/Spark35NestedSchemaPruning.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hudi.{HoodieBaseRelation, SparkAdapterSupport} +import org.apache.spark.sql.HoodieSpark3CatalystPlanUtils +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, NamedExpression, ProjectionOverSchema} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.util.SchemaUtils.restoreOriginalOutputNames + +/** + * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation. + * By "physical column", we mean a column as defined in the data source format like Parquet format + * or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL + * column, and a nested Parquet column corresponds to a [[StructField]]. + * + * NOTE: This class is borrowed from Spark 3.2.1, with modifications adapting it to handle [[HoodieBaseRelation]], + * instead of [[HadoopFsRelation]] + */ +class Spark35NestedSchemaPruning extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.expressions.SchemaPruning._ + + override def apply(plan: LogicalPlan): LogicalPlan = + if (conf.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + private def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, filters, + // NOTE: This is modified to accommodate for Hudi's custom relations, given that original + // [[NestedSchemaPruning]] rule is tightly coupled w/ [[HadoopFsRelation]] + // TODO generalize to any file-based relation + l @ LogicalRelation(relation: HoodieBaseRelation, _, _, _)) + if relation.canPruneRelationSchema => + + prunePhysicalColumns(l.output, projects, filters, relation.dataSchema, + prunedDataSchema => { + val prunedRelation = + relation.updatePrunedDataSchema(prunedSchema = prunedDataSchema) + buildPrunedRelation(l, prunedRelation) + }).getOrElse(op) + } + + /** + * This method returns optional logical plan. `None` is returned if no nested field is required or + * all nested fields are required. + */ + private def prunePhysicalColumns(output: Seq[AttributeReference], + projects: Seq[NamedExpression], + filters: Seq[Expression], + dataSchema: StructType, + outputRelationBuilder: StructType => LogicalRelation): Option[LogicalPlan] = { + val (normalizedProjects, normalizedFilters) = + normalizeAttributeRefNames(output, projects, filters) + val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters) + + // If requestedRootFields includes a nested field, continue. Otherwise, + // return op + if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) { + val prunedDataSchema = pruneSchema(dataSchema, requestedRootFields) + + // If the data schema is different from the pruned data schema, continue. Otherwise, + // return op. We effect this comparison by counting the number of "leaf" fields in + // each schemata, assuming the fields in prunedDataSchema are a subset of the fields + // in dataSchema. 
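+      // For example, struct<a: int, b: struct<c: int, d: int>> has three leaf fields; if pruning
+      // drops `b.d`, the pruned schema has two, the counts differ, and the rule rewrites the plan.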
+ if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { + val planUtils = SparkAdapterSupport.sparkAdapter.getCatalystPlanUtils.asInstanceOf[HoodieSpark3CatalystPlanUtils] + + val prunedRelation = outputRelationBuilder(prunedDataSchema) + val projectionOverSchema = planUtils.projectOverSchema(prunedDataSchema, AttributeSet(output)) + + Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, + prunedRelation, projectionOverSchema)) + } else { + None + } + } else { + None + } + } + + /** + * Normalizes the names of the attribute references in the given projects and filters to reflect + * the names in the given logical relation. This makes it possible to compare attributes and + * fields by name. Returns a tuple with the normalized projects and filters, respectively. + */ + private def normalizeAttributeRefNames(output: Seq[AttributeReference], + projects: Seq[NamedExpression], + filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = { + val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap + val normalizedProjects = projects.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }).map { case expr: NamedExpression => expr } + val normalizedFilters = filters.map(_.transform { + case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) => + att.withName(normalizedAttNameMap(att.exprId)) + }) + (normalizedProjects, normalizedFilters) + } + + /** + * Builds the new output [[Project]] Spark SQL operator that has the `leafNode`. + */ + private def buildNewProjection(projects: Seq[NamedExpression], + normalizedProjects: Seq[NamedExpression], + filters: Seq[Expression], + prunedRelation: LogicalRelation, + projectionOverSchema: ProjectionOverSchema): Project = { + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + // Construct the new projections of our Project by + // rewriting the original projections + val newProjects = normalizedProjects.map(_.transformDown { + case projectionOverSchema(expr) => expr + }).map { case expr: NamedExpression => expr } + + if (log.isDebugEnabled) { + logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") + } + + Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild) + } + + /** + * Builds a pruned logical relation from the output of the output relation and the schema of the + * pruned base relation. + */ + private def buildPrunedRelation(outputRelation: LogicalRelation, + prunedBaseRelation: BaseRelation): LogicalRelation = { + val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema) + outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput) + } + + // Prune the given output to make it consistent with `requiredSchema`. 
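+  // For example, if the relation originally exposed `items#3: struct<id, name, price>` (exprId #3
+  // hypothetical) and the pruned schema keeps only `items: struct<id>`, the rebuilt attribute is
+  // re-assigned exprId #3 via `withExprId` below, so parent operators referencing it still resolve.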
+ private def getPrunedOutput(output: Seq[AttributeReference], + requiredSchema: StructType): Seq[AttributeReference] = { + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = output.map(att => (att.name, att.exprId)).toMap + DataTypeUtils.toAttributes(requiredSchema) + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + } + + /** + * Counts the "leaf" fields of the given dataType. Informally, this is the + * number of fields of non-complex data type in the tree representation of + * [[DataType]]. + */ + private def countLeaves(dataType: DataType): Int = { + dataType match { + case array: ArrayType => countLeaves(array.elementType) + case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType) + case struct: StructType => + struct.map(field => countLeaves(field.dataType)).sum + case _ => 1 + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35DataSourceUtils.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35DataSourceUtils.scala new file mode 100644 index 0000000000000..4e08f975eefbf --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35DataSourceUtils.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.SPARK_VERSION_METADATA_KEY +import org.apache.spark.sql.internal.{SQLConf, LegacyBehaviorPolicy} +import org.apache.spark.util.Utils + +object Spark35DataSourceUtils { + + /** + * NOTE: This method was copied from [[Spark32PlusDataSourceUtils]], and is required to maintain runtime + * compatibility against Spark 3.5.0 + */ + // scalastyle:off + def int96RebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 3.0 and earlier follow the legacy hybrid calendar and we need to + // rebase the INT96 timestamp values. + // Files written by Spark 3.1 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. 
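+    // For example, a footer carrying "org.apache.spark.version" -> "3.0.1", or any version plus
+    // the "org.apache.spark.legacyINT96" marker, resolves to LEGACY; a 3.1.0+ version without the
+    // marker resolves to CORRECTED, and a footer with no version key falls back to `modeByConfig`.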
+ if (version < "3.1.0" || lookupFileMeta("org.apache.spark.legacyINT96") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + + /** + * NOTE: This method was copied from Spark 3.2.0, and is required to maintain runtime + * compatibility against Spark 3.2.0 + */ + // scalastyle:off + def datetimeRebaseMode(lookupFileMeta: String => String, + modeByConfig: String): LegacyBehaviorPolicy.Value = { + if (Utils.isTesting && SQLConf.get.getConfString("spark.test.forceNoRebase", "") == "true") { + return LegacyBehaviorPolicy.CORRECTED + } + // If there is no version, we return the mode specified by the config. + Option(lookupFileMeta(SPARK_VERSION_METADATA_KEY)).map { version => + // Files written by Spark 2.4 and earlier follow the legacy hybrid calendar and we need to + // rebase the datetime values. + // Files written by Spark 3.0 and latter may also need the rebase if they were written with + // the "LEGACY" rebase mode. + if (version < "3.0.0" || lookupFileMeta("org.apache.spark.legacyDateTime") != null) { + LegacyBehaviorPolicy.LEGACY + } else { + LegacyBehaviorPolicy.CORRECTED + } + }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + } + // scalastyle:on + +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35LegacyHoodieParquetFileFormat.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35LegacyHoodieParquetFileFormat.scala new file mode 100644 index 0000000000000..dd70aa08b8562 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/Spark35LegacyHoodieParquetFileFormat.scala @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapred.FileSplit
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
+import org.apache.hudi.HoodieSparkUtils
+import org.apache.hudi.client.utils.SparkInternalSchemaConverter
+import org.apache.hudi.common.fs.FSUtils
+import org.apache.hudi.common.util.InternalSchemaCache
+import org.apache.hudi.common.util.StringUtils.isNullOrEmpty
+import org.apache.hudi.common.util.collection.Pair
+import org.apache.hudi.internal.schema.InternalSchema
+import org.apache.hudi.internal.schema.action.InternalSchemaMerger
+import org.apache.hudi.internal.schema.utils.{InternalSchemaUtils, SerDeHelper}
+import org.apache.parquet.filter2.compat.FilterCompat
+import org.apache.parquet.filter2.predicate.FilterApi
+import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
+import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{Cast, JoinedRow}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.WholeStageCodegenExec
+import org.apache.spark.sql.execution.datasources.parquet.Spark35LegacyHoodieParquetFileFormat._
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.types.{AtomicType, DataType, StructField, StructType}
+import org.apache.spark.util.SerializableConfiguration
+
+/**
+ * This class is an extension of [[ParquetFileFormat]] overriding Spark-specific behavior
+ * that's not possible to customize in any other way.
+ *
+ * NOTE: This is a version of the [[ParquetFileFormat]] impl from Spark 3.2.1 w/ the following changes applied to it:
+ *   1. Avoiding appending partition values to the rows read from the data file
+ *   2. Schema on-read
+ */
+class Spark35LegacyHoodieParquetFileFormat(private val shouldAppendPartitionValues: Boolean) extends ParquetFileFormat {
+
+  override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
+    val conf = sparkSession.sessionState.conf
+    conf.parquetVectorizedReaderEnabled && schema.forall(_.dataType.isInstanceOf[AtomicType])
+  }
+
+  def supportsColumnar(sparkSession: SparkSession, schema: StructType): Boolean = {
+    val conf = sparkSession.sessionState.conf
+    // Only output columnar if there is WSCG to read it.
+    val requiredWholeStageCodegenSettings =
+      conf.wholeStageEnabled && !WholeStageCodegenExec.isTooManyFields(conf, schema)
+    requiredWholeStageCodegenSettings &&
+      supportBatch(sparkSession, schema)
+  }
+
+  override def buildReaderWithPartitionValues(sparkSession: SparkSession,
+                                              dataSchema: StructType,
+                                              partitionSchema: StructType,
+                                              requiredSchema: StructType,
+                                              filters: Seq[Filter],
+                                              options: Map[String, String],
+                                              hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
+    hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
+    hadoopConf.set(
+      ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
+      requiredSchema.json)
+    hadoopConf.set(
+      ParquetWriteSupport.SPARK_ROW_SCHEMA,
+      requiredSchema.json)
+    hadoopConf.set(
+      SQLConf.SESSION_LOCAL_TIMEZONE.key,
+      sparkSession.sessionState.conf.sessionLocalTimeZone)
+    hadoopConf.setBoolean(
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
+    hadoopConf.setBoolean(
+      SQLConf.CASE_SENSITIVE.key,
+      sparkSession.sessionState.conf.caseSensitiveAnalysis)
+
+    ParquetWriteSupport.setSchema(requiredSchema, hadoopConf)
+
+    // Sets flags for `ParquetToSparkSchemaConverter`
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_BINARY_AS_STRING.key,
+      sparkSession.sessionState.conf.isParquetBinaryAsString)
+    hadoopConf.setBoolean(
+      SQLConf.PARQUET_INT96_AS_TIMESTAMP.key,
+      sparkSession.sessionState.conf.isParquetINT96AsTimestamp)
+    // Using string value of this conf to preserve compatibility across spark versions.
+    hadoopConf.setBoolean(
+      SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key,
+      sparkSession.sessionState.conf.getConfString(
+        SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key,
+        SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValueString).toBoolean
+    )
+    hadoopConf.setBoolean(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key, sparkSession.sessionState.conf.parquetInferTimestampNTZEnabled)
+    val internalSchemaStr = hadoopConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA)
+    // For Spark DataSource v1, there's no Physical Plan projection/schema pruning w/in Spark itself,
+    // therefore it's safe to do schema projection here
+    if (!isNullOrEmpty(internalSchemaStr)) {
+      val prunedInternalSchemaStr =
+        pruneInternalSchema(internalSchemaStr, requiredSchema)
+      hadoopConf.set(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA, prunedInternalSchemaStr)
+    }
+
+    val broadcastedHadoopConf =
+      sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
+
+    // TODO: if you move this into the closure it reverts to the default values.
+    // If true, enable using the custom RecordReader for parquet. This only works for
+    // a subset of the types (no complex types).
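+    // NOTE: All of the SQLConf-derived flags below are deliberately materialized on the
+    //       driver, before the returned function is serialized to executors. A sketch of the
+    //       pitfall this avoids (illustrative only, per the TODO above):
+    //
+    //         (file: PartitionedFile) => {
+    //           // resolving sqlConf.parquetVectorizedReaderEnabled here, inside the closure,
+    //           // would pick up default values rather than the planning session's settings
+    //         }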
+    val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields)
+    val sqlConf = sparkSession.sessionState.conf
+    val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
+    val enableVectorizedReader: Boolean =
+      sqlConf.parquetVectorizedReaderEnabled &&
+        resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+    val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled
+    val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
+    val capacity = sqlConf.parquetVectorizedReaderBatchSize
+    val enableParquetFilterPushDown: Boolean = sqlConf.parquetFilterPushDown
+    val pushDownDate = sqlConf.parquetFilterPushDownDate
+    val pushDownTimestamp = sqlConf.parquetFilterPushDownTimestamp
+    val pushDownDecimal = sqlConf.parquetFilterPushDownDecimal
+    val pushDownStringStartWith = sqlConf.parquetFilterPushDownStringPredicate
+    val pushDownInFilterThreshold = sqlConf.parquetFilterPushDownInFilterThreshold
+    val isCaseSensitive = sqlConf.caseSensitiveAnalysis
+    val parquetOptions = new ParquetOptions(options, sparkSession.sessionState.conf)
+    val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
+    val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead
+    val timeZoneId = Option(sqlConf.sessionLocalTimeZone)
+    // Should always be set by FileSourceScanExec creating this.
+    // Check conf before checking option, to allow working around an issue by changing conf.
+    val returningBatch = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled &&
+      supportsColumnar(sparkSession, resultSchema)
+
+    (file: PartitionedFile) => {
+      assert(!shouldAppendPartitionValues || file.partitionValues.numFields == partitionSchema.size)
+
+      val filePath = file.filePath.toPath
+      val split = new FileSplit(filePath, file.start, file.length, Array.empty[String])
+
+      val sharedConf = broadcastedHadoopConf.value.value
+
+      // Fetch internal schema
+      val internalSchemaStr = sharedConf.get(SparkInternalSchemaConverter.HOODIE_QUERY_SCHEMA)
+      // Internal schema has to be pruned at this point
+      val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr)
+
+      var shouldUseInternalSchema = !isNullOrEmpty(internalSchemaStr) && querySchemaOption.isPresent
+
+      val tablePath = sharedConf.get(SparkInternalSchemaConverter.HOODIE_TABLE_PATH)
+      val fileSchema = if (shouldUseInternalSchema) {
+        val commitInstantTime = FSUtils.getCommitTime(filePath.getName).toLong
+        val validCommits = sharedConf.get(SparkInternalSchemaConverter.HOODIE_VALID_COMMITS_LIST)
+        InternalSchemaCache.getInternalSchemaByVersionId(commitInstantTime, tablePath, sharedConf, if (validCommits == null) "" else validCommits)
+      } else {
+        null
+      }
+
+      lazy val footerFileMetaData =
+        ParquetFooterReader.readFooter(sharedConf, filePath, SKIP_ROW_GROUPS).getFileMetaData
+      // Try to push down filters when filter push-down is enabled.
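+      // (For illustration: a source filter such as GreaterThan("age", 18) is converted by
+      //  the ParquetFilters instance below into a Parquet FilterPredicate, e.g.
+      //  FilterApi.gt(FilterApi.intColumn("age"), 18: Integer), which lets entire row groups
+      //  be skipped based on footer statistics.)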
+ val pushed = if (enableParquetFilterPushDown) { + val parquetSchema = footerFileMetaData.getSchema + val parquetFilters = if (HoodieSparkUtils.gteqSpark3_2_1) { + // NOTE: Below code could only be compiled against >= Spark 3.2.1, + // and unfortunately won't compile against Spark 3.2.0 + // However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1 + val datetimeRebaseSpec = + DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + new ParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseSpec) + } else { + // Spark 3.2.0 + val datetimeRebaseMode = + Spark35DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + createParquetFilters( + parquetSchema, + pushDownDate, + pushDownTimestamp, + pushDownDecimal, + pushDownStringStartWith, + pushDownInFilterThreshold, + isCaseSensitive, + datetimeRebaseMode) + } + filters.map(rebuildFilterFromParquet(_, fileSchema, querySchemaOption.orElse(null))) + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(parquetFilters.createFilter) + .reduceOption(FilterApi.and) + } else { + None + } + + // PARQUET_INT96_TIMESTAMP_CONVERSION says to apply timezone conversions to int96 timestamps' + // *only* if the file was created by something other than "parquet-mr", so check the actual + // writer here for this file. We have to do this per-file, as each file in the table may + // have different writers. + // Define isCreatedByParquetMr as function to avoid unnecessary parquet footer reads. + def isCreatedByParquetMr: Boolean = + footerFileMetaData.getCreatedBy().startsWith("parquet-mr") + + val convertTz = + if (timestampConversion && !isCreatedByParquetMr) { + Some(DateTimeUtils.getZoneId(sharedConf.get(SQLConf.SESSION_LOCAL_TIMEZONE.key))) + } else { + None + } + + val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) + + // Clone new conf + val hadoopAttemptConf = new Configuration(broadcastedHadoopConf.value.value) + val typeChangeInfos: java.util.Map[Integer, Pair[DataType, DataType]] = if (shouldUseInternalSchema) { + val mergedInternalSchema = new InternalSchemaMerger(fileSchema, querySchemaOption.get(), true, true).mergeSchema() + val mergedSchema = SparkInternalSchemaConverter.constructSparkSchemaFromInternalSchema(mergedInternalSchema) + + hadoopAttemptConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, mergedSchema.json) + + SparkInternalSchemaConverter.collectTypeChangedCols(querySchemaOption.get(), mergedInternalSchema) + } else { + val (implicitTypeChangeInfo, sparkRequestSchema) = HoodieParquetFileFormatHelper.buildImplicitSchemaChangeInfo(hadoopAttemptConf, footerFileMetaData, requiredSchema) + if (!implicitTypeChangeInfo.isEmpty) { + shouldUseInternalSchema = true + hadoopAttemptConf.set(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA, sparkRequestSchema.json) + } + implicitTypeChangeInfo + } + + val hadoopAttemptContext = + new TaskAttemptContextImpl(hadoopAttemptConf, attemptId) + + // Try to push down filters when filter push-down is enabled. + // Notice: This push-down is RowGroups level, not individual records. 
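+      //         Row groups are skipped wholesale via their min/max statistics; rows that
+      //         survive row-group pruning are re-filtered record-by-record only when
+      //         `enableRecordFilter` is set (see the FilterCompat.get(...) usage below).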
+ if (pushed.isDefined) { + ParquetInputFormat.setFilterPredicate(hadoopAttemptContext.getConfiguration, pushed.get) + } + val taskContext = Option(TaskContext.get()) + if (enableVectorizedReader) { + val vectorizedReader = + if (shouldUseInternalSchema) { + val int96RebaseSpec = + DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + val datetimeRebaseSpec = + DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + new Spark32PlusHoodieVectorizedParquetRecordReader( + convertTz.orNull, + datetimeRebaseSpec.mode.toString, + datetimeRebaseSpec.timeZone, + int96RebaseSpec.mode.toString, + int96RebaseSpec.timeZone, + enableOffHeapColumnVector && taskContext.isDefined, + capacity, + typeChangeInfos) + } else if (HoodieSparkUtils.gteqSpark3_2_1) { + // NOTE: Below code could only be compiled against >= Spark 3.2.1, + // and unfortunately won't compile against Spark 3.2.0 + // However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1 + val int96RebaseSpec = + DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + val datetimeRebaseSpec = + DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + new VectorizedParquetRecordReader( + convertTz.orNull, + datetimeRebaseSpec.mode.toString, + datetimeRebaseSpec.timeZone, + int96RebaseSpec.mode.toString, + int96RebaseSpec.timeZone, + enableOffHeapColumnVector && taskContext.isDefined, + capacity) + } else { + // Spark 3.2.0 + val datetimeRebaseMode = + Spark35DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + val int96RebaseMode = + Spark35DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + createVectorizedParquetRecordReader( + convertTz.orNull, + datetimeRebaseMode.toString, + int96RebaseMode.toString, + enableOffHeapColumnVector && taskContext.isDefined, + capacity) + } + + // SPARK-37089: We cannot register a task completion listener to close this iterator here + // because downstream exec nodes have already registered their listeners. Since listeners + // are executed in reverse order of registration, a listener registered here would close the + // iterator while downstream exec nodes are still running. When off-heap column vectors are + // enabled, this can cause a use-after-free bug leading to a segfault. + // + // Instead, we use FileScanRDD's task completion listener to close this iterator. + val iter = new RecordReaderIterator(vectorizedReader) + try { + vectorizedReader.initialize(split, hadoopAttemptContext) + + // NOTE: We're making appending of the partitioned values to the rows read from the + // data file configurable + if (shouldAppendPartitionValues) { + logDebug(s"Appending $partitionSchema ${file.partitionValues}") + vectorizedReader.initBatch(partitionSchema, file.partitionValues) + } else { + vectorizedReader.initBatch(StructType(Nil), InternalRow.empty) + } + + if (returningBatch) { + vectorizedReader.enableReturningBatches() + } + + // UnsafeRowParquetRecordReader appends the columns internally to avoid another copy. + iter.asInstanceOf[Iterator[InternalRow]] + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. 
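+            //       (at this point the vectorized reader may already hold open file handles
+            //       and allocated column buffers)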
+ iter.close() + throw e + } + } else { + logDebug(s"Falling back to parquet-mr") + val readSupport = if (HoodieSparkUtils.gteqSpark3_2_1) { + // ParquetRecordReader returns InternalRow + // NOTE: Below code could only be compiled against >= Spark 3.2.1, + // and unfortunately won't compile against Spark 3.2.0 + // However this code is runtime-compatible w/ both Spark 3.2.0 and >= Spark 3.2.1 + val int96RebaseSpec = + DataSourceUtils.int96RebaseSpec(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + val datetimeRebaseSpec = + DataSourceUtils.datetimeRebaseSpec(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + new ParquetReadSupport( + convertTz, + enableVectorizedReader = false, + datetimeRebaseSpec, + int96RebaseSpec) + } else { + val datetimeRebaseMode = + Spark35DataSourceUtils.datetimeRebaseMode(footerFileMetaData.getKeyValueMetaData.get, datetimeRebaseModeInRead) + val int96RebaseMode = + Spark35DataSourceUtils.int96RebaseMode(footerFileMetaData.getKeyValueMetaData.get, int96RebaseModeInRead) + createParquetReadSupport( + convertTz, + /* enableVectorizedReader = */ false, + datetimeRebaseMode, + int96RebaseMode) + } + + val reader = if (pushed.isDefined && enableRecordFilter) { + val parquetFilter = FilterCompat.get(pushed.get, null) + new ParquetRecordReader[InternalRow](readSupport, parquetFilter) + } else { + new ParquetRecordReader[InternalRow](readSupport) + } + val iter = new RecordReaderIterator[InternalRow](reader) + try { + reader.initialize(split, hadoopAttemptContext) + + val fullSchema = DataTypeUtils.toAttributes(requiredSchema) ++ DataTypeUtils.toAttributes(partitionSchema) + val unsafeProjection = if (typeChangeInfos.isEmpty) { + GenerateUnsafeProjection.generate(fullSchema, fullSchema) + } else { + // find type changed. + val newSchema = new StructType(requiredSchema.fields.zipWithIndex.map { case (f, i) => + if (typeChangeInfos.containsKey(i)) { + StructField(f.name, typeChangeInfos.get(i).getRight, f.nullable, f.metadata) + } else f + }) + val newFullSchema = DataTypeUtils.toAttributes(newSchema) ++ DataTypeUtils.toAttributes(partitionSchema) + val castSchema = newFullSchema.zipWithIndex.map { case (attr, i) => + if (typeChangeInfos.containsKey(i)) { + val srcType = typeChangeInfos.get(i).getRight + val dstType = typeChangeInfos.get(i).getLeft + val needTimeZone = Cast.needsTimeZone(srcType, dstType) + Cast(attr, dstType, if (needTimeZone) timeZoneId else None) + } else attr + } + GenerateUnsafeProjection.generate(castSchema, newFullSchema) + } + + // NOTE: We're making appending of the partitioned values to the rows read from the + // data file configurable + if (!shouldAppendPartitionValues || partitionSchema.length == 0) { + // There is no partition columns + iter.map(unsafeProjection) + } else { + val joinedRow = new JoinedRow() + iter.map(d => unsafeProjection(joinedRow(d, file.partitionValues))) + } + } catch { + case e: Throwable => + // SPARK-23457: In case there is an exception in initialization, close the iterator to + // avoid leaking resources. 
+ iter.close() + throw e + } + } + } + } +} + +object Spark35LegacyHoodieParquetFileFormat { + + /** + * NOTE: This method is specific to Spark 3.2.0 + */ + private def createParquetFilters(args: Any*): ParquetFilters = { + // NOTE: ParquetFilters ctor args contain Scala enum, therefore we can't look it + // up by arg types, and have to instead rely on the number of args based on individual class; + // the ctor order is not guaranteed + val ctor = classOf[ParquetFilters].getConstructors.maxBy(_.getParameterCount) + ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*) + .asInstanceOf[ParquetFilters] + } + + /** + * NOTE: This method is specific to Spark 3.2.0 + */ + private def createParquetReadSupport(args: Any*): ParquetReadSupport = { + // NOTE: ParquetReadSupport ctor args contain Scala enum, therefore we can't look it + // up by arg types, and have to instead rely on the number of args based on individual class; + // the ctor order is not guaranteed + val ctor = classOf[ParquetReadSupport].getConstructors.maxBy(_.getParameterCount) + ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*) + .asInstanceOf[ParquetReadSupport] + } + + /** + * NOTE: This method is specific to Spark 3.2.0 + */ + private def createVectorizedParquetRecordReader(args: Any*): VectorizedParquetRecordReader = { + // NOTE: ParquetReadSupport ctor args contain Scala enum, therefore we can't look it + // up by arg types, and have to instead rely on the number of args based on individual class; + // the ctor order is not guaranteed + val ctor = classOf[VectorizedParquetRecordReader].getConstructors.maxBy(_.getParameterCount) + ctor.newInstance(args.map(_.asInstanceOf[AnyRef]): _*) + .asInstanceOf[VectorizedParquetRecordReader] + } + + def pruneInternalSchema(internalSchemaStr: String, requiredSchema: StructType): String = { + val querySchemaOption = SerDeHelper.fromJson(internalSchemaStr) + if (querySchemaOption.isPresent && requiredSchema.nonEmpty) { + val prunedSchema = SparkInternalSchemaConverter.convertAndPruneStructTypeToInternalSchema(requiredSchema, querySchemaOption.get()) + SerDeHelper.toJson(prunedSchema) + } else { + internalSchemaStr + } + } + + private def rebuildFilterFromParquet(oldFilter: Filter, fileSchema: InternalSchema, querySchema: InternalSchema): Filter = { + if (fileSchema == null || querySchema == null) { + oldFilter + } else { + oldFilter match { + case eq: EqualTo => + val newAttribute = InternalSchemaUtils.reBuildFilterName(eq.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else eq.copy(attribute = newAttribute) + case eqs: EqualNullSafe => + val newAttribute = InternalSchemaUtils.reBuildFilterName(eqs.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else eqs.copy(attribute = newAttribute) + case gt: GreaterThan => + val newAttribute = InternalSchemaUtils.reBuildFilterName(gt.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else gt.copy(attribute = newAttribute) + case gtr: GreaterThanOrEqual => + val newAttribute = InternalSchemaUtils.reBuildFilterName(gtr.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else gtr.copy(attribute = newAttribute) + case lt: LessThan => + val newAttribute = InternalSchemaUtils.reBuildFilterName(lt.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else lt.copy(attribute = newAttribute) + case lte: LessThanOrEqual => + val newAttribute = InternalSchemaUtils.reBuildFilterName(lte.attribute, fileSchema, querySchema) + if 
(newAttribute.isEmpty) AlwaysTrue else lte.copy(attribute = newAttribute) + case i: In => + val newAttribute = InternalSchemaUtils.reBuildFilterName(i.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else i.copy(attribute = newAttribute) + case isn: IsNull => + val newAttribute = InternalSchemaUtils.reBuildFilterName(isn.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else isn.copy(attribute = newAttribute) + case isnn: IsNotNull => + val newAttribute = InternalSchemaUtils.reBuildFilterName(isnn.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else isnn.copy(attribute = newAttribute) + case And(left, right) => + And(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema)) + case Or(left, right) => + Or(rebuildFilterFromParquet(left, fileSchema, querySchema), rebuildFilterFromParquet(right, fileSchema, querySchema)) + case Not(child) => + Not(rebuildFilterFromParquet(child, fileSchema, querySchema)) + case ssw: StringStartsWith => + val newAttribute = InternalSchemaUtils.reBuildFilterName(ssw.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else ssw.copy(attribute = newAttribute) + case ses: StringEndsWith => + val newAttribute = InternalSchemaUtils.reBuildFilterName(ses.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else ses.copy(attribute = newAttribute) + case sc: StringContains => + val newAttribute = InternalSchemaUtils.reBuildFilterName(sc.attribute, fileSchema, querySchema) + if (newAttribute.isEmpty) AlwaysTrue else sc.copy(attribute = newAttribute) + case AlwaysTrue => + AlwaysTrue + case AlwaysFalse => + AlwaysFalse + case _ => + AlwaysTrue + } + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/Spark35ResolveHudiAlterTableCommand.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/Spark35ResolveHudiAlterTableCommand.scala new file mode 100644 index 0000000000000..160804f62b370 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/Spark35ResolveHudiAlterTableCommand.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.hudi
+
+import org.apache.hudi.common.config.HoodieCommonConfig
+import org.apache.hudi.internal.schema.action.TableChange.ColumnChangeID
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.analysis.ResolvedTable
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table
+import org.apache.spark.sql.hudi.command.{AlterTableCommand => HudiAlterTableCommand}
+
+/**
+ * Rule to resolve, normalize and rewrite column names based on case sensitivity
+ * for ALTER TABLE column commands.
+ */
+class Spark35ResolveHudiAlterTableCommand(sparkSession: SparkSession) extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    if (schemaEvolutionEnabled) {
+      plan.resolveOperatorsUp {
+        case set@SetTableProperties(ResolvedHoodieV2TablePlan(t), _) if set.resolved =>
+          HudiAlterTableCommand(t.v1Table, set.changes, ColumnChangeID.PROPERTY_CHANGE)
+        case unSet@UnsetTableProperties(ResolvedHoodieV2TablePlan(t), _, _) if unSet.resolved =>
+          HudiAlterTableCommand(t.v1Table, unSet.changes, ColumnChangeID.PROPERTY_CHANGE)
+        case drop@DropColumns(ResolvedHoodieV2TablePlan(t), _, _) if drop.resolved =>
+          HudiAlterTableCommand(t.v1Table, drop.changes, ColumnChangeID.DELETE)
+        case add@AddColumns(ResolvedHoodieV2TablePlan(t), _) if add.resolved =>
+          HudiAlterTableCommand(t.v1Table, add.changes, ColumnChangeID.ADD)
+        case renameColumn@RenameColumn(ResolvedHoodieV2TablePlan(t), _, _) if renameColumn.resolved =>
+          HudiAlterTableCommand(t.v1Table, renameColumn.changes, ColumnChangeID.UPDATE)
+        case alter@AlterColumn(ResolvedHoodieV2TablePlan(t), _, _, _, _, _, _) if alter.resolved =>
+          HudiAlterTableCommand(t.v1Table, alter.changes, ColumnChangeID.UPDATE)
+        case replace@ReplaceColumns(ResolvedHoodieV2TablePlan(t), _) if replace.resolved =>
+          HudiAlterTableCommand(t.v1Table, replace.changes, ColumnChangeID.REPLACE)
+      }
+    } else {
+      plan
+    }
+  }
+
+  private def schemaEvolutionEnabled: Boolean =
+    sparkSession.sessionState.conf.getConfString(HoodieCommonConfig.SCHEMA_EVOLUTION_ENABLE.key,
+      HoodieCommonConfig.SCHEMA_EVOLUTION_ENABLE.defaultValue.toString).toBoolean
+
+  object ResolvedHoodieV2TablePlan {
+    def unapply(plan: LogicalPlan): Option[HoodieInternalV2Table] = {
+      plan match {
+        case ResolvedTable(_, _, v2Table: HoodieInternalV2Table, _) => Some(v2Table)
+        case _ => None
+      }
+    }
+  }
+}
+
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark35Analysis.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark35Analysis.scala
new file mode 100644
index 0000000000000..f137c9dea6c30
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark35Analysis.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hudi.analysis
+
+import org.apache.hudi.DefaultSource
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.hudi.ProvidesHoodieConfig
+import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table
+import org.apache.spark.sql.{SQLContext, SparkSession}
+
+/**
+ * NOTE: PLEASE READ CAREFULLY
+ *
+ * Since Hudi relations don't currently implement the DS V2 Read API, we have to fall back
+ * to V1 here. Such a fallback has a considerable performance impact, therefore it's only
+ * performed in cases where the V2 API has to be used. Currently the only such use-case is
+ * the Schema Evolution feature
+ *
+ * Check out HUDI-4178 for more details
+ */
+case class HoodieSpark35DataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan]
+  with ProvidesHoodieConfig {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // The only place we're avoiding fallback is in [[AlterTableCommand]]s since
+    // current implementation relies on DSv2 features
+    case _: AlterTableCommand => plan
+
+    // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose
+    //       target relation as a child (even though there's no good reason for that)
+    case iis@InsertIntoStatement(rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _, _) =>
+      iis.copy(table = convertToV1(rv2, v2Table))
+
+    case _ =>
+      plan.resolveOperatorsDown {
+        case rv2@DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table)
+      }
+  }
+
+  private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = {
+    val output = rv2.output
+    val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table)
+    val relation = new DefaultSource().createRelation(new SQLContext(sparkSession),
+      buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema)
+
+    LogicalRelation(relation, output, catalogTable, isStreaming = false)
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlAstBuilder.scala
new file mode 100644
index 0000000000000..cb6720ea3586b
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlAstBuilder.scala
@@ -0,0 +1,3496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last}
+import org.apache.spark.sql.catalyst.parser.ParserUtils.{checkDuplicateClauses, checkDuplicateKeys, entry, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin}
+import org.apache.spark.sql.catalyst.parser.{EnhancedLogicalPlan, ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.catalyst.util.DateTimeUtils._
+import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils, truncatedString}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.BucketSpecHelper
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
+import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils.isTesting
+import org.apache.spark.util.random.RandomSampler
+
+import java.util.Locale
+import java.util.concurrent.TimeUnit
+import javax.xml.bind.DatatypeConverter
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * The AstBuilder for HoodieSqlParser, which parses the AST tree into a logical plan. Only the
+ * extended SQL syntax (e.g. MERGE INTO) is parsed here; all other SQL syntax is handled by the
+ * delegate parser, SparkSqlParser.
+ */
+class HoodieSpark3_5ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface)
+  extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging {
+
+  protected def typedVisit[T](ctx: ParseTree): T = {
+    ctx.accept(this).asInstanceOf[T]
+  }
+
+  /**
+   * Override the default behavior for all visit methods. This will only return a non-null result
+   * when the context has only one child. This is done because there is no generic method to
+   * combine the results of the context children. In all other cases null is returned.
+ */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) + } + + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) + } else { + Option(v).map(string) + } + + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw new ParseException( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw new ParseException( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + + TimeTravelRelation(plan, timestamp, version) + } + + // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + typedVisit[DataType](ctx.dataType) + } + + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + val schema = StructType(visitColTypeList(ctx.colTypeList)) + withOrigin(ctx)(schema) + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + + // Apply CTEs + query.optionalMap(ctx.ctes)(withCTE) + } + + override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) { + val dmlStmt = plan(ctx.dmlStatementNoWith) + // Apply CTEs + dmlStmt.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. 
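+    //   e.g. WITH t AS (SELECT 1), t AS (SELECT 2) SELECT * FROM t
+    //   is rejected by the check below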
+ val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys + if (duplicates.nonEmpty) { + throw new ParseException(s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.", ctx) + } + UnresolvedWith(plan, ctes.toSeq) + } + + /** + * Create a logical query plan for a hive-style FROM statement body. + */ + private def withFromStatementBody( + ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // two cases for transforms and selects + if (ctx.transformClause != null) { + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } else { + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } + } + + override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + val selects = ctx.fromStatementBody.asScala.map { body => + withFromStatementBody(body, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses) + } + // If there are multiple SELECT just UNION them together into one query. + if (selects.length == 1) { + selects.head + } else { + Union(selects.toSeq) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)( + (columnAliases, plan) => + UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan) + ) + SubqueryAlias(ctx.name.getText, subQuery) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { body => + withInsertInto(body.insertInto, + withFromStatementBody(body.fromStatementBody, from). + optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + if (inserts.length == 1) { + inserts.head + } else { + Union(inserts.toSeq) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + withInsertInto( + ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + } + + /** + * Parameters used for writing query to a table: + * (UnresolvedRelation, tableColumnList, partitionKeys, ifPartitionNotExists). + */ + type InsertTableParams = (UnresolvedRelation, Seq[String], Map[String, Option[String]], Boolean) + + /** + * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). 
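+   *
+   * NOTE: Both directory-insert visitors below reject these statements outright (this
+   *       extended parser does not support INSERT OVERWRITE DIRECTORY); the type is kept
+   *       so the visitor signatures stay aligned with Spark's own AstBuilder.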
+ */ + type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String]) + + /** + * Add an + * {{{ + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList] + * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] + * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] + * }}} + * operation to logical plan + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + ctx match { + case table: InsertIntoTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = false, + ifPartitionNotExists) + case table: InsertOverwriteTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = true, + ifPartitionNotExists) + case dir: InsertOverwriteDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case hiveDir: InsertOverwriteHiveDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case _ => + throw new ParseException("Invalid InsertIntoContext", ctx) + } + } + + /** + * Add an INSERT INTO TABLE operation to the logical plan. + */ + override def visitInsertIntoTable( + ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, false) + } + + /** + * Add an INSERT OVERWRITE TABLE operation to the logical plan. + */ + override def visitInsertOverwriteTable( + ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { + assert(ctx.OVERWRITE() != null) + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(", "), ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, ctx.EXISTS() != null) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteDir( + ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. 
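+   *
+   * (Unsupported in this extended parser: the override below unconditionally throws
+   * [[ParseException]] instead.)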
+ */
+  override def visitInsertOverwriteHiveDir(
+      ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) {
+    throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx)
+  }
+
+  private def getTableAliasWithoutColumnAlias(
+      ctx: TableAliasContext, op: String): Option[String] = {
+    if (ctx == null) {
+      None
+    } else {
+      val ident = ctx.strictIdentifier()
+      if (ctx.identifierList() != null) {
+        throw new ParseException(s"Column aliases are not allowed in $op.", ctx.identifierList())
+      }
+      if (ident != null) Some(ident.getText) else None
+    }
+  }
+
+  override def visitDeleteFromTable(
+      ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) {
+    val table = createUnresolvedRelation(ctx.multipartIdentifier())
+    val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE")
+    val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+    val predicate = if (ctx.whereClause() != null) {
+      Some(expression(ctx.whereClause().booleanExpression()))
+    } else {
+      None
+    }
+    // NOTE: A DELETE without a WHERE clause carries no predicate, so default to a literal
+    //       TRUE condition rather than unconditionally unwrapping the Option
+    DeleteFromTable(aliasedTable, predicate.getOrElse(Literal.TrueLiteral))
+  }
+
+  override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) {
+    val table = createUnresolvedRelation(ctx.multipartIdentifier())
+    val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE")
+    val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table)
+    val assignments = withAssignments(ctx.setClause().assignmentList())
+    val predicate = if (ctx.whereClause() != null) {
+      Some(expression(ctx.whereClause().booleanExpression()))
+    } else {
+      None
+    }
+
+    UpdateTable(aliasedTable, assignments, predicate)
+  }
+
+  private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] =
+    withOrigin(assignCtx) {
+      assignCtx.assignment().asScala.map { assign =>
+        Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)),
+          expression(assign.value))
+      }.toSeq
+    }
+
+  override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) {
+    val targetTable = createUnresolvedRelation(ctx.target)
+    val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE")
+    val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable)
+
+    val sourceTableOrQuery = if (ctx.source != null) {
+      createUnresolvedRelation(ctx.source)
+    } else if (ctx.sourceQuery != null) {
+      visitQuery(ctx.sourceQuery)
+    } else {
+      throw new ParseException("Empty source for merge: you should specify a source" +
+        " table/subquery in merge.", ctx.source)
+    }
+    val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE")
+    val aliasedSource =
+      sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery)
+
+    val mergeCondition = expression(ctx.mergeCondition)
+
+    val matchedActions = ctx.matchedClause().asScala.map {
+      clause => {
+        if (clause.matchedAction().DELETE() != null) {
+          DeleteAction(Option(clause.matchedCond).map(expression))
+        } else if (clause.matchedAction().UPDATE() != null) {
+          val condition = Option(clause.matchedCond).map(expression)
+          if (clause.matchedAction().ASTERISK() != null) {
+            UpdateStarAction(condition)
+          } else {
+            UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList()))
+          }
+        } else {
+          // It should not be here.
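+          //       (the grammar only permits DELETE and UPDATE as matched actions, so this
+          //       branch is purely defensive)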
+ throw new ParseException(s"Unrecognized matched action: ${clause.matchedAction().getText}", + clause.matchedAction()) + } + } + } + val notMatchedActions = ctx.notMatchedClause().asScala.map { + clause => { + if (clause.notMatchedAction().INSERT() != null) { + val condition = Option(clause.notMatchedCond).map(expression) + if (clause.notMatchedAction().ASTERISK() != null) { + InsertStarAction(condition) + } else { + val columns = clause.notMatchedAction().columns.multipartIdentifier() + .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr))) + val values = clause.notMatchedAction().expression().asScala.map(expression) + if (columns.size != values.size) { + throw new ParseException("The number of inserted values cannot match the fields.", + clause.notMatchedAction()) + } + InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq) + } + } else { + // It should not be here. + throw new ParseException(s"Unrecognized not matched action: ${clause.notMatchedAction().getText}", + clause.notMatchedAction()) + } + } + } + if (matchedActions.isEmpty && notMatchedActions.isEmpty) { + throw new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx) + } + // children being empty means that the condition is not set + val matchedActionSize = matchedActions.length + if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one MATCHED clauses in a MERGE " + + "statement, only the last MATCHED clause can omit the condition.", ctx) + } + val notMatchedActionSize = notMatchedActions.length + if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " + + "statement, only the last NOT MATCHED clause can omit the condition.", ctx) + } + + MergeIntoTable( + aliasedTarget, + aliasedSource, + mergeCondition, + matchedActions.toSeq, + notMatchedActions.toSeq, + Seq.empty) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + val legacyNullAsString = + conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL) + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString)) + name -> value + } + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. + if (conf.caseSensitiveAnalysis) { + checkDuplicateKeys(parts.toSeq, ctx) + } else { + checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx) + } + parts.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. 
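+   *
+   * e.g. (illustrative) both {{{ PARTITION (ts = 1618556399) }}} and
+   * {{{ PARTITION (ts = '1618556399') }}} normalize to the partition value "1618556399".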
+ */ + protected def visitStringConstant( + ctx: ConstantContext, + legacyNullAsString: Boolean): String = withOrigin(ctx) { + expression(ctx) match { + case Literal(null, _) if !legacyNullAsString => null + case l@Literal(null, _) => l.toString + case l: Literal => + // TODO For v2 commands, we will cast the string back to its actual value, + // which is a waste and can be improved in the future. + Cast(l, StringType, Some(conf.sessionLocalTimeZone)).eval().toString + case other => + throw new IllegalArgumentException(s"Only literals are allowed in the " + + s"partition spec, but got ${other.sql}") + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + withRepartitionByExpression(ctx, expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem).toSeq, + global = false, + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + withRepartitionByExpression(ctx, expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + + // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a clause for DISTRIBUTE BY. 
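+   *
+   * e.g. (roughly) {{{ SELECT * FROM t DISTRIBUTE BY id }}} becomes
+   * {{{ RepartitionByExpression(Seq('id), <plan for t>, None) }}},
+   * where `None` leaves the target partition count to the planner.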
+ */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + RepartitionByExpression(expressions, query, None) + } + + override def visitTransformQuerySpecification( + ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitRegularQuerySpecification( + ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitNamedExpressionSeq( + ctx: NamedExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + } + + override def visitExpressionSeq(ctx: ExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.expression.asScala) + .map(typedVisit[Expression]) + } + + /** + * Create a logical plan using a having clause. + */ + private def withHavingClause( + ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + UnresolvedHaving(predicate, plan) + } + + /** + * Create a logical plan using a where clause. + */ + private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx.booleanExpression), plan) + } + + /** + * Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan. + */ + private def withTransformQuerySpecification( + ctx: ParserRuleContext, + transformClause: TransformClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (transformClause.setQuantifier != null) { + throw new ParseException("TRANSFORM does not support DISTINCT/ALL in inputs", transformClause.setQuantifier) + } + // Create the attributes. + val (attributes, schemaLess) = if (transformClause.colTypeList != null) { + // Typed return columns. + (DataTypeUtils.toAttributes(createSchema(transformClause.colTypeList)), false) + } else if (transformClause.identifierSeq != null) { + // Untyped return columns. 
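+      // (each named output column becomes a nullable StringType attribute)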
+ val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitExpressionSeq(transformClause.expressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct = false) + + ScriptTransformation( + string(transformClause.script), + attributes, + plan, + withScriptIOSchema( + ctx, + transformClause.inRowFormat, + transformClause.recordWriter, + transformClause.outRowFormat, + transformClause.recordReader, + schemaLess + ) + ) + } + + /** + * Add a regular (SELECT) query specification to a logical plan. The query specification + * is the core of the logical plan, this is where sourcing (FROM clause), projection (SELECT), + * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withSelectQuerySpecification( + ctx: ParserRuleContext, + selectClause: SelectClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val isDistinct = selectClause.setQuantifier() != null && + selectClause.setQuantifier().DISTINCT() != null + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitNamedExpressionSeq(selectClause.namedExpressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct) + + // Hint + selectClause.hints.asScala.foldRight(plan)(withHints) + } + + def visitCommonSelectQueryClausePlan( + relation: LogicalPlan, + expressions: Seq[Expression], + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + isDistinct: Boolean): LogicalPlan = { + // Add lateral views. + val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + + def createProject() = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + val withProject = if (aggregationClause == null && havingClause != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + val predicate = expression(havingClause.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. + withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) + } + } else if (aggregationClause != null) { + val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) + aggregate.optionalMap(havingClause)(withHavingClause) + } else { + // When hitting this branch, `having` must be null. 
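+      // and a plain projection over the filtered relation suffices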
+ createProject() + } + + // Distinct + val withDistinct = if (isDistinct) { + Distinct(withProject) + } else { + withProject + } + + // Window + val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) + + withWindow + } + + // Script Transform's input/output format. + type ScriptIOFormat = + (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = { + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "TOK_TABLEROWFORMATLINES" -> value + } + + (entries, None, Seq.empty, None) + } + + /** + * Create a [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: ParserRuleContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + + def format(fmt: RowFormatContext): ScriptIOFormat = fmt match { + case c: RowFormatDelimitedContext => + getRowFormatDelimited(c) + + case c: RowFormatSerdeContext => + throw new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + + // SPARK-32106: When there is no definition about format, we return empty result + // to use a built-in default Serde in SparkScriptTransformationExec. + case null => + (Nil, None, Seq.empty, None) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat) + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. 
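+   * For example, FROM a, b, c is planned as Join(Join(a, b, Inner), c, Inner).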
+ */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left) { (left, right) => + if (relation.LATERAL != null) { + if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) { + throw new ParseException(s"LATERAL can only be used with subquery", relation.relationPrimary) + } + LateralJoin(left, LateralSubquery(right), Inner, None) + } else { + Join(left, right, Inner, None, JoinHint.NONE) + } + } + withJoinRelations(join, relation) + } + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case HoodieSqlBaseParser.UNION if all => + Union(left, right) + case HoodieSqlBaseParser.UNION => + Distinct(Union(left, right)) + case HoodieSqlBaseParser.INTERSECT if all => + Intersect(left, right, isAll = true) + case HoodieSqlBaseParser.INTERSECT => + Intersect(left, right, isAll = false) + case HoodieSqlBaseParser.EXCEPT if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.EXCEPT => + Except(left, right, isAll = false) + case HoodieSqlBaseParser.SETMINUS if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.SETMINUS => + Except(left, right, isAll = false) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindowClause( + ctx: WindowClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowTuples = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + } + baseWindowTuples.groupBy(_._1).foreach { kv => + if (kv._2.size > 1) { + throw new ParseException(s"The definition of window '${kv._1}' is repetitive", ctx) + } + } + val baseWindowMap = baseWindowTuples.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity).toMap, query) + } + + /** + * Add an [[Aggregate]] to a logical plan. 
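+   * Handles plain GROUP BY as well as the CUBE, ROLLUP and GROUPING SETS variants,
+   * e.g. GROUP BY a, b WITH ROLLUP is planned as
+   * Aggregate(Seq(Rollup(Seq(Seq(a), Seq(b)))), selectExpressions, child).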
+ */ + private def withAggregationClause( + ctx: AggregationClauseContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) { + val groupByExpressions = expressionList(ctx.groupingExpressions) + if (ctx.GROUPING != null) { + // GROUP BY ... GROUPING SETS (...) + // `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping + // expressions that do not exist in GROUPING SETS (...), and the value is always null. + // For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output + // of column `c` is always null. + val groupingSets = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) + Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), + selectExpressions, query) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (ctx.CUBE != null) { + Seq(Cube(groupByExpressions.map(Seq(_)))) + } else if (ctx.ROLLUP != null) { + Seq(Rollup(groupByExpressions.map(Seq(_)))) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } else { + val groupByExpressions = + ctx.groupingExpressionsWithGroupingAnalytics.asScala + .map(groupByExpr => { + val groupingAnalytics = groupByExpr.groupingAnalytics + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics) + } else { + expression(groupByExpr.expression) + } + }) + Aggregate(groupByExpressions.toSeq, selectExpressions, query) + } + } + + override def visitGroupingAnalytics( + groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = { + val groupingSets = groupingAnalytics.groupingSet.asScala + .map(_.expression.asScala.map(e => expression(e)).toSeq) + if (groupingAnalytics.CUBE != null) { + // CUBE(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw new ParseException(s"Empty set in CUBE grouping sets is not supported.", groupingAnalytics) + } + Cube(groupingSets.toSeq) + } else if (groupingAnalytics.ROLLUP != null) { + // ROLLUP(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw new ParseException(s"Empty set in ROLLUP grouping sets is not supported.", groupingAnalytics) + } + Rollup(groupingSets.toSeq) + } else { + assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null) + val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr => + val groupingAnalytics = expr.groupingAnalytics() + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs + } else { + Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq) + } + } + GroupingSets(groupingSets.toSeq) + } + } + + /** + * Add [[UnresolvedHint]]s to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + var plan = query + ctx.hintStatements.asScala.reverse.foreach { stmt => + plan = UnresolvedHint(stmt.hintName.getText, + stmt.parameters.asScala.map(expression).toSeq, plan) + } + plan + } + + /** + * Add a [[Pivot]] to a logical plan. 
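+   * e.g. FROM t PIVOT (sum(x) FOR y IN (1, 2)) is planned as
+   * Pivot(None, y, Seq(1, 2), Seq(sum(x)), t).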
+   */
+  private def withPivot(
+      ctx: PivotClauseContext,
+      query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
+    val aggregates = Option(ctx.aggregates).toSeq
+      .flatMap(_.namedExpression.asScala)
+      .map(typedVisit[Expression])
+    val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
+      UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText)
+    } else {
+      CreateStruct(
+        ctx.pivotColumn.identifiers.asScala.map(
+          identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq)
+    }
+    val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
+    Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query)
+  }
+
+  /**
+   * Create a Pivot column value with or without an alias.
+   */
+  override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
+    val e = expression(ctx.expression)
+    if (ctx.identifier != null) {
+      Alias(e, ctx.identifier.getText)()
+    } else {
+      e
+    }
+  }
+
+  /**
+   * Add a [[Generate]] (Lateral View) to a logical plan.
+   */
+  private def withGenerate(
+      query: LogicalPlan,
+      ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) {
+    val expressions = expressionList(ctx.expression)
+    Generate(
+      UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions),
+      unrequiredChildIndex = Nil,
+      outer = ctx.OUTER != null,
+      // scalastyle:off caselocale
+      Some(ctx.tblName.getText.toLowerCase),
+      // scalastyle:on caselocale
+      ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq,
+      query)
+  }
+
+  /**
+   * Create a single relation referenced in a FROM clause. This method is used when a part of the
+   * join condition is nested, for example:
+   * {{{
+   *   select * from t1 join (t2 cross join t3) on col1 = col2
+   * }}}
+   */
+  override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) {
+    withJoinRelations(plan(ctx.relationPrimary), ctx)
+  }
+
+  /**
+   * Join one or more [[LogicalPlan]]s to the current logical plan.
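+   * Handles INNER, CROSS, LEFT/RIGHT/FULL OUTER, LEFT SEMI and LEFT ANTI joins, the
+   * NATURAL and USING variants, and LATERAL subqueries, e.g. t1 JOIN t2 USING (id).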
+ */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) { + throw new ParseException(s"LATERAL can only be used with subquery", join.right) + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + if (join.LATERAL != null) { + throw new ParseException("LATERAL join with USING join is not supported", ctx) + } + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new ParseException(s"Unimplemented joinCriteria: $c", ctx) + case None if join.NATURAL != null => + if (join.LATERAL != null) { + throw new ParseException("LATERAL join with NATURAL join is not supported", ctx) + } + if (baseJoinType == Cross) { + throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (join.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw new ParseException(s"Unsupported LATERAL join type ${joinType.toString}", ctx) + } + LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition) + } else { + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + } + } + } + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. 
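+      // e.g. TABLESAMPLE(50 PERCENT) reaches this helper with fraction = 0.5.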
+ val eps = RandomSampler.roundingEpsilon + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + } + + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => + Limit(expression(ctx.expression), query) + + case ctx: SampleByPercentileContext => + val fraction = ctx.percentage.getText.toDouble + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) + + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException(s"TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException(s"$bytesStr is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } + + case ctx: SampleByBucketContext if ctx.ON() != null => + if (ctx.identifier != null) { + throw new ParseException(s"TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) + } else { + throw new ParseException(s"TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) + } + + case ctx: SampleByBucketContext => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.query) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier)) + } + + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + val name = getFunctionIdentifier(func.functionName) + if (name.database.nonEmpty) { + operationNotAllowed(s"table valued function cannot specify database name: $name", ctx) + } + + val tvf = UnresolvedTableValuedFunction(name, func.expression.asScala.map(expression).toSeq) + + val tvfAliases = if (aliases.nonEmpty) UnresolvedTVFAliases(name, tvf, aliases) else tvf + + tvfAliases.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. 
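+    // e.g. VALUES (1, 'a'), (2, 'b') AS t(id, name): each tuple below becomes one row.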
+ val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 + } + } + + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) + } else { + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") + } + + val table = UnresolvedInlineTable(aliases, rows.toSeq) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) + * }}} + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample) + mayApplyAliasPlan(ctx.tableAlias, relation) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample) + if (ctx.tableAlias.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + SubqueryAlias("__auto_generated_subquery_name", relation) + } else { + mayApplyAliasPlan(ctx.tableAlias, relation) + } + } + + /** + * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]]. + */ + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and + * column aliases for a [[LogicalPlan]]. + */ + private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = { + if (tableAlias.strictIdentifier != null) { + val alias = tableAlias.strictIdentifier.getText + if (tableAlias.identifierList != null) { + val columnNames = visitIdentifierList(tableAlias.identifierList) + SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan)) + } else { + SubqueryAlias(alias, plan) + } + } else { + plan + } + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. 
+ */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.ident.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + + /** + * Create an expression from the given context. This method just passes the context on to the + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression).toSeq + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.name != null) { + Alias(e, ctx.name.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case HoodieSqlBaseParser.AND => And.apply _ + case HoodieSqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. 
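+    // e.g. a AND b AND c AND d parses into a left-deep tree; the operands are first
+    // flattened into [a, b, c, d] and then rebuilt as ((a AND b) AND (c AND d)).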
+    val contexts = ArrayBuffer(ctx.right)
+    var current = ctx.left
+
+    def collectContexts: Boolean = current match {
+      case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType =>
+        contexts += lbc.right
+        current = lbc.left
+        true
+      case _ =>
+        contexts += current
+        false
+    }
+
+    while (collectContexts) {
+      // No body - all updates take place inside collectContexts.
+    }
+
+    // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them
+    // into expressions.
+    val expressions = contexts.reverseMap(expression)
+
+    // Create a balanced tree.
+    def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+      case 0 =>
+        expressions(low)
+      case 1 =>
+        expressionCombiner(expressions(low), expressions(high))
+      case x =>
+        val mid = low + x / 2
+        expressionCombiner(
+          reduceToExpressionTree(low, mid),
+          reduceToExpressionTree(mid + 1, high))
+    }
+
+    reduceToExpressionTree(0, expressions.size - 1)
+  }
+
+  /**
+   * Invert a boolean expression.
+   */
+  override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+    Not(expression(ctx.booleanExpression()))
+  }
+
+  /**
+   * Create a filtering correlated sub-query (EXISTS).
+   */
+  override def visitExists(ctx: ExistsContext): Expression = {
+    Exists(plan(ctx.query))
+  }
+
+  /**
+   * Create a comparison expression. This compares two expressions. The following comparison
+   * operators are supported:
+   * - Equal: '=' or '=='
+   * - Null-safe Equal: '<=>'
+   * - Not Equal: '<>' or '!='
+   * - Less than: '<'
+   * - Less than or Equal: '<='
+   * - Greater than: '>'
+   * - Greater than or Equal: '>='
+   */
+  override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+    val left = expression(ctx.left)
+    val right = expression(ctx.right)
+    val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+    operator.getSymbol.getType match {
+      case HoodieSqlBaseParser.EQ =>
+        EqualTo(left, right)
+      case HoodieSqlBaseParser.NSEQ =>
+        EqualNullSafe(left, right)
+      case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ =>
+        Not(EqualTo(left, right))
+      case HoodieSqlBaseParser.LT =>
+        LessThan(left, right)
+      case HoodieSqlBaseParser.LTE =>
+        LessThanOrEqual(left, right)
+      case HoodieSqlBaseParser.GT =>
+        GreaterThan(left, right)
+      case HoodieSqlBaseParser.GTE =>
+        GreaterThanOrEqual(left, right)
+    }
+  }
+
+  /**
+   * Create a predicated expression. A predicated expression is a normal expression with a
+   * predicate attached to it, for example:
+   * {{{
+   *    a + 1 IS NULL
+   * }}}
+   */
+  override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) {
+    val e = expression(ctx.valueExpression)
+    if (ctx.predicate != null) {
+      withPredicate(e, ctx.predicate)
+    } else {
+      e
+    }
+  }
+
+  /**
+   * Add a predicate to the given expression. Supported expressions are:
+   * - (NOT) BETWEEN
+   * - (NOT) IN
+   * - (NOT) LIKE (ANY | SOME | ALL)
+   * - (NOT) RLIKE
+   * - IS (NOT) NULL.
+   * - IS (NOT) (TRUE | FALSE | UNKNOWN)
+   * - IS (NOT) DISTINCT FROM
+   */
+  private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) {
+    // Invert a predicate if it has a valid NOT clause.
+    def invertIfNotDefined(e: Expression): Expression = ctx.NOT match {
+      case null => e
+      case not => Not(e)
+    }
+
+    def getValueExpressions(e: Expression): Seq[Expression] = e match {
+      case c: CreateNamedStruct => c.valExprs
+      case other => Seq(other)
+    }
+
+    // Create the predicate.
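+    // e.g. e BETWEEN l AND u becomes l <= e AND e <= u; x IN (1, 2) becomes In(x, [1, 2]).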
+ ctx.kind.getType match { + case HoodieSqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case HoodieSqlBaseParser.IN if ctx.query != null => + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + case HoodieSqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq)) + case HoodieSqlBaseParser.LIKE => + Option(ctx.quantifier).map(_.getType) match { + case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAny or NotLikeAny instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAny(e, patterns) + case _ => NotLikeAny(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or) + } + case Some(HoodieSqlBaseParser.ALL) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAll or NotLikeAll instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAll(e, patterns) + case _ => NotLikeAll(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And) + } + case _ => + val escapeChar = Option(ctx.escapeChar).map(string).map { str => + if (str.length != 1) { + throw new ParseException("Invalid escape string. Escape string must contain only one character.", ctx) + } + str.charAt(0) + }.getOrElse('\\') + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) + } + case HoodieSqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case HoodieSqlBaseParser.NULL => + IsNull(e) + case HoodieSqlBaseParser.TRUE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(true)) + case _ => Not(EqualNullSafe(e, Literal(true))) + } + case HoodieSqlBaseParser.FALSE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(false)) + case _ => Not(EqualNullSafe(e, Literal(false))) + } + case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match { + case null => IsUnknown(e) + case _ => IsNotUnknown(e) + } + case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case HoodieSqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) + } + } + + /** + * Create a binary arithmetic expression. 
The following arithmetic operators are supported:
+   * - Multiplication: '*'
+   * - Division: '/'
+   * - Hive Long Division: 'DIV'
+   * - Modulo: '%'
+   * - Addition: '+'
+   * - Subtraction: '-'
+   * - String Concatenation: '||'
+   * - Binary AND: '&'
+   * - Binary XOR: '^'
+   * - Binary OR: '|'
+   */
+  override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+    val left = expression(ctx.left)
+    val right = expression(ctx.right)
+    ctx.operator.getType match {
+      case HoodieSqlBaseParser.ASTERISK =>
+        Multiply(left, right)
+      case HoodieSqlBaseParser.SLASH =>
+        Divide(left, right)
+      case HoodieSqlBaseParser.PERCENT =>
+        Remainder(left, right)
+      case HoodieSqlBaseParser.DIV =>
+        IntegralDivide(left, right)
+      case HoodieSqlBaseParser.PLUS =>
+        Add(left, right)
+      case HoodieSqlBaseParser.MINUS =>
+        Subtract(left, right)
+      case HoodieSqlBaseParser.CONCAT_PIPE =>
+        Concat(left :: right :: Nil)
+      case HoodieSqlBaseParser.AMPERSAND =>
+        BitwiseAnd(left, right)
+      case HoodieSqlBaseParser.HAT =>
+        BitwiseXor(left, right)
+      case HoodieSqlBaseParser.PIPE =>
+        BitwiseOr(left, right)
+    }
+  }
+
+  /**
+   * Create a unary arithmetic expression. The following arithmetic operators are supported:
+   * - Plus: '+'
+   * - Minus: '-'
+   * - Bitwise Not: '~'
+   */
+  override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+    val value = expression(ctx.valueExpression)
+    ctx.operator.getType match {
+      case HoodieSqlBaseParser.PLUS =>
+        UnaryPositive(value)
+      case HoodieSqlBaseParser.MINUS =>
+        UnaryMinus(value)
+      case HoodieSqlBaseParser.TILDE =>
+        BitwiseNot(value)
+    }
+  }
+
+  override def visitCurrentLike(ctx: CurrentLikeContext): Expression = withOrigin(ctx) {
+    if (conf.ansiEnabled) {
+      ctx.name.getType match {
+        case HoodieSqlBaseParser.CURRENT_DATE =>
+          CurrentDate()
+        case HoodieSqlBaseParser.CURRENT_TIMESTAMP =>
+          CurrentTimestamp()
+        case HoodieSqlBaseParser.CURRENT_USER =>
+          CurrentUser()
+      }
+    } else {
+      // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
+      // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`.
+      UnresolvedAttribute.quoted(ctx.name.getText)
+    }
+  }
+
+  /**
+   * Create a [[Cast]] expression.
+   */
+  override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+    val rawDataType = typedVisit[DataType](ctx.dataType())
+    val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType)
+    val cast = ctx.name.getType match {
+      case HoodieSqlBaseParser.CAST =>
+        Cast(expression(ctx.expression), dataType)
+
+      case HoodieSqlBaseParser.TRY_CAST =>
+        Cast(expression(ctx.expression), dataType, evalMode = EvalMode.TRY)
+    }
+    cast.setTagValue(Cast.USER_SPECIFIED_CAST, true)
+    cast
+  }
+
+  /**
+   * Create a [[CreateStruct]] expression.
+   */
+  override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) {
+    CreateStruct.create(ctx.argument.asScala.map(expression).toSeq)
+  }
+
+  /**
+   * Create a [[First]] expression.
+   */
+  override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
+    val ignoreNullsExpr = ctx.IGNORE != null
+    First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+  }
+
+  /**
+   * Create a [[Last]] expression.
+   */
+  override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
+    val ignoreNullsExpr = ctx.IGNORE != null
+    Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
+  }
+
+  /**
+   * Create a Position expression.
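+   * e.g. POSITION('b' IN 'abc') is planned as StringLocate('b', 'abc').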
+   */
+  override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) {
+    new StringLocate(expression(ctx.substr), expression(ctx.str))
+  }
+
+  /**
+   * Create an Extract expression.
+   */
+  override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) {
+    val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source))
+    UnresolvedFunction("extract", arguments, isDistinct = false)
+  }
+
+  /**
+   * Create a Substring/Substr expression.
+   */
+  override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) {
+    if (ctx.len != null) {
+      Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len))
+    } else {
+      new Substring(expression(ctx.str), expression(ctx.pos))
+    }
+  }
+
+  /**
+   * Create a Trim expression.
+   */
+  override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) {
+    val srcStr = expression(ctx.srcStr)
+    val trimStr = Option(ctx.trimStr).map(expression)
+    Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match {
+      case HoodieSqlBaseParser.BOTH =>
+        StringTrim(srcStr, trimStr)
+      case HoodieSqlBaseParser.LEADING =>
+        StringTrimLeft(srcStr, trimStr)
+      case HoodieSqlBaseParser.TRAILING =>
+        StringTrimRight(srcStr, trimStr)
+      case other =>
+        throw new ParseException("Function trim doesn't support trim " +
+          s"type $other. Please use BOTH, LEADING or TRAILING as trim type", ctx)
+    }
+  }
+
+  /**
+   * Create an Overlay expression.
+   */
+  override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
+    val input = expression(ctx.input)
+    val replace = expression(ctx.replace)
+    val position = expression(ctx.position)
+    val lengthOpt = Option(ctx.length).map(expression)
+    lengthOpt match {
+      case Some(length) => Overlay(input, replace, position, length)
+      case None => new Overlay(input, replace, position)
+    }
+  }
+
+  /**
+   * Create a (windowed) Function expression.
+   */
+  override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+    // Create the function call.
+    val name = ctx.functionName.getText
+    val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+    // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13
+    val arguments = ctx.argument.asScala.map(expression).toSeq match {
+      case Seq(UnresolvedStar(None))
+        if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct =>
+        // Transform COUNT(*) into COUNT(1).
+        Seq(Literal(1))
+      case expressions =>
+        expressions
+    }
+    val filter = Option(ctx.where).map(expression(_))
+    val ignoreNulls =
+      Option(ctx.nullsOption).map(_.getType == HoodieSqlBaseParser.IGNORE).getOrElse(false)
+    val function = UnresolvedFunction(
+      getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls)
+
+    // Check if the function is evaluated in a windowed context.
+    ctx.windowSpec match {
+      case spec: WindowRefContext =>
+        UnresolvedWindowExpression(function, visitWindowRef(spec))
+      case spec: WindowDefContext =>
+        WindowExpression(function, visitWindowDef(spec))
+      case _ => function
+    }
+  }
+
+  /**
+   * Create a function database (optional) and name pair.
+   */
+  protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = {
+    visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq)
+  }
+
+  /**
+   * Create a function database (optional) and name pair.
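+   * e.g. db.fn maps to FunctionIdentifier("fn", Some("db")).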
+   */
+  private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = {
+    texts match {
+      case Seq(db, fn) => FunctionIdentifier(fn, Option(db))
+      case Seq(fn) => FunctionIdentifier(fn, None)
+      case other =>
+        throw new ParseException(s"Unsupported function name '${texts.mkString(".")}'", ctx)
+    }
+  }
+
+  /**
+   * Get a function identifier consisting of an optional database name and a function name.
+   */
+  protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = {
+    if (ctx.qualifiedName != null) {
+      visitFunctionName(ctx.qualifiedName)
+    } else {
+      FunctionIdentifier(ctx.getText, None)
+    }
+  }
+
+  protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = {
+    if (ctx.qualifiedName != null) {
+      ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq
+    } else {
+      Seq(ctx.getText)
+    }
+  }
+
+  /**
+   * Create a [[LambdaFunction]].
+   */
+  override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
+    val arguments = ctx.identifier().asScala.map { name =>
+      UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+    }
+    val function = expression(ctx.expression).transformUp {
+      case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+    }
+    LambdaFunction(function, arguments.toSeq)
+  }
+
+  /**
+   * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+   */
+  override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+    WindowSpecReference(ctx.name.getText)
+  }
+
+  /**
+   * Create a window definition, i.e. [[WindowSpecDefinition]].
+   */
+  override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+    // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+    val partition = ctx.partition.asScala.map(expression)
+    val order = ctx.sortItem.asScala.map(visitSortItem)
+
+    // RANGE/ROWS BETWEEN ...
+    val frameSpecOption = Option(ctx.windowFrame).map { frame =>
+      val frameType = frame.frameType.getType match {
+        case HoodieSqlBaseParser.RANGE => RangeFrame
+        case HoodieSqlBaseParser.ROWS => RowFrame
+      }
+
+      SpecifiedWindowFrame(
+        frameType,
+        visitFrameBound(frame.start),
+        Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow))
+    }
+
+    WindowSpecDefinition(
+      partition.toSeq,
+      order.toSeq,
+      frameSpecOption.getOrElse(UnspecifiedFrame))
+  }
+
+  /**
+   * Create or resolve a frame boundary expression.
+   */
+  override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) {
+    def value: Expression = {
+      val e = expression(ctx.expression)
+      validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx)
+      e
+    }
+
+    ctx.boundType.getType match {
+      case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null =>
+        UnboundedPreceding
+      case HoodieSqlBaseParser.PRECEDING =>
+        UnaryMinus(value)
+      case HoodieSqlBaseParser.CURRENT =>
+        CurrentRow
+      case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null =>
+        UnboundedFollowing
+      case HoodieSqlBaseParser.FOLLOWING =>
+        value
+    }
+  }
+
+  /**
+   * Create a [[CreateStruct]] expression.
+   */
+  override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) {
+    CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq)
+  }
+
+  /**
+   * Create a [[ScalarSubquery]] expression.
+   */
+  override def visitSubqueryExpression(
+      ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) {
+    ScalarSubquery(plan(ctx.query))
+  }
+
+  /**
+   * Create a value based [[CaseWhen]] expression.
This has the following SQL form:
+   * {{{
+   *   CASE [expression]
+   *    WHEN [value] THEN [expression]
+   *    ...
+   *    ELSE [expression]
+   *   END
+   * }}}
+   */
+  override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) {
+    val e = expression(ctx.value)
+    val branches = ctx.whenClause.asScala.map { wCtx =>
+      (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result))
+    }
+    CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+  }
+
+  /**
+   * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax:
+   * {{{
+   *   CASE
+   *    WHEN [predicate] THEN [expression]
+   *    ...
+   *    ELSE [expression]
+   *   END
+   * }}}
+   *
+   * @param ctx the parse tree
+   */
+  override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) {
+    val branches = ctx.whenClause.asScala.map { wCtx =>
+      (expression(wCtx.condition), expression(wCtx.result))
+    }
+    CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression))
+  }
+
+  /**
+   * Currently, regexes are only supported in expressions of SELECT statements; in other
+   * places, e.g. WHERE `(a)?+.+` = 2, a regex is not meaningful.
+   */
+  private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) {
+    var parent = ctx.getParent
+    var rtn = false
+    while (parent != null) {
+      if (parent.isInstanceOf[NamedExpressionContext]) {
+        rtn = true
+      }
+      parent = parent.getParent
+    }
+    rtn
+  }
+
+  /**
+   * Create a dereference expression. The return type depends on the type of the parent.
+   * If the parent is an [[UnresolvedAttribute]], it can be an [[UnresolvedAttribute]] or
+   * an [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression,
+   * it can be an [[UnresolvedExtractValue]].
+   */
+  override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
+    val attr = ctx.fieldName.getText
+    expression(ctx.base) match {
+      case unresolved_attr@UnresolvedAttribute(nameParts) =>
+        ctx.fieldName.getStart.getText match {
+          case escapedIdentifier(columnNameRegex)
+            if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+            UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name),
+              conf.caseSensitiveAnalysis)
+          case _ =>
+            UnresolvedAttribute(nameParts :+ attr)
+        }
+      case e =>
+        UnresolvedExtractValue(e, Literal(attr))
+    }
+  }
+
+  /**
+   * Create an [[UnresolvedAttribute]] expression or an [[UnresolvedRegex]] if it is a regex
+   * quoted in ``
+   */
+  override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+    ctx.getStart.getText match {
+      case escapedIdentifier(columnNameRegex)
+        if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) =>
+        UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
+      case _ =>
+        UnresolvedAttribute.quoted(ctx.getText)
+    }
+  }
+
+  /**
+   * Create an [[UnresolvedExtractValue]] expression; this is used for subscript access to an
+   * array.
+   */
+  override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+    UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+  }
+
+  /**
+   * Create an expression for an expression between parentheses. This is needed because the ANTLR
+   * visitor cannot automatically convert the nested context into an expression.
+   */
+  override def visitParenthesizedExpression(
+      ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+    expression(ctx.expression)
+  }
+
+  /**
+   * Create a [[SortOrder]] expression.
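+   * e.g. ORDER BY a DESC NULLS FIRST yields SortOrder(a, Descending, NullsFirst, Seq.empty).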
+ */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + val direction = if (ctx.DESC != null) { + Descending + } else { + Ascending + } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date, Timestamp, Interval and Binary typed literals are supported. + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) + + def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = { + f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse { + throw new ParseException(s"Cannot parse the $valueType value: $value", ctx) + } + } + + def constructTimestampLTZLiteral(value: String): Literal = { + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialTs = convertSpecialTimestamp(value, zoneId).map(Literal(_, TimestampType)) + specialTs.getOrElse(toLiteral(stringToTimestamp(_, zoneId), TimestampType)) + } + + try { + valueType match { + case "DATE" => + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType)) + specialDate.getOrElse(toLiteral(stringToDate, DateType)) + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. + case "TIMESTAMP_NTZ" if isTesting => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse(toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType)) + case "TIMESTAMP_LTZ" if isTesting => + constructTimestampLTZLiteral(value) + case "TIMESTAMP" => + SQLConf.get.timestampType match { + case TimestampNTZType => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse { + val containsTimeZonePart = + DateTimeUtils.parseTimestampString(UTF8String.fromString(value))._2.isDefined + // If the input string contains time zone part, return a timestamp with local time + // zone literal. 
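+              // e.g. a value such as '2019-10-06 10:11:12.34+08' carries an offset and is
+              // therefore turned into a TIMESTAMP_LTZ literal.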
+ if (containsTimeZonePart) { + constructTimestampLTZLiteral(value) + } else { + toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType) + } + } + + case TimestampType => + constructTimestampLTZLiteral(value) + } + + case "INTERVAL" => + val interval = try { + IntervalUtils.stringToInterval(UTF8String.fromString(value)) + } catch { + case e: IllegalArgumentException => + val ex = new ParseException(s"Cannot parse the INTERVAL value: $value", ctx) + ex.setStackTrace(e.getStackTrace) + throw ex + } + if (!conf.legacyIntervalEnabled) { + val units = value + .split("\\s") + .map(_.toLowerCase(Locale.ROOT).stripSuffix("s")) + .filter(s => s != "interval" && s.matches("[a-z]+")) + constructMultiUnitsIntervalLiteral(ctx, interval, units) + } else { + Literal(interval, CalendarIntervalType) + } + case "X" => + val padding = if (value.length % 2 != 0) "0" else "" + Literal(DatatypeConverter.parseHexBinary(padding + value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } catch { + case e: IllegalArgumentException => + val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType") + throw new ParseException(message, ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue) + case v if v.isValidLong => + Literal(v.longValue) + case v => Literal(v.underlying()) + } + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a decimal literal for a regular decimal number or a scientific decimal number. + */ + override def visitLegacyDecimalLiteral( + ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a double literal for number with an exponent, e.g. 1E-30 + */ + override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = { + numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */ + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** Create a numeric literal expression. 
*/ + private def numericLiteral( + ctx: NumberContext, + rawStrippedQualifier: String, + minValue: BigDecimal, + maxValue: BigDecimal, + typeName: String)(converter: String => Any): Literal = withOrigin(ctx) { + try { + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal $rawStrippedQualifier does not " + + s"fit in range [$minValue, $maxValue] for type $typeName", ctx) + } + Literal(converter(rawStrippedQualifier)) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) + } + + /** + * Create a Float Literal expression. + */ + override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat) + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. + */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(x => stringWithoutUnescape(x.getSymbol)).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } + } + + /** + * Create an [[UnresolvedRelation]] from a multi-part identifier context. 
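+   * e.g. db.tbl becomes UnresolvedRelation(Seq("db", "tbl")).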
+   */
+  private def createUnresolvedRelation(
+      ctx: MultipartIdentifierContext): UnresolvedRelation = withOrigin(ctx) {
+    UnresolvedRelation(visitMultipartIdentifier(ctx))
+  }
+
+  /**
+   * Construct a [[Literal]] from a [[CalendarInterval]] and
+   * units represented as a [[Seq]] of [[String]].
+   */
+  private def constructMultiUnitsIntervalLiteral(
+      ctx: ParserRuleContext,
+      calendarInterval: CalendarInterval,
+      units: Seq[String]): Literal = {
+    var yearMonthFields = Set.empty[Byte]
+    var dayTimeFields = Set.empty[Byte]
+    for (unit <- units) {
+      if (YearMonthIntervalType.stringToField.contains(unit)) {
+        yearMonthFields += YearMonthIntervalType.stringToField(unit)
+      } else if (DayTimeIntervalType.stringToField.contains(unit)) {
+        dayTimeFields += DayTimeIntervalType.stringToField(unit)
+      } else if (unit == "week") {
+        dayTimeFields += DayTimeIntervalType.DAY
+      } else {
+        assert(unit == "millisecond" || unit == "microsecond")
+        dayTimeFields += DayTimeIntervalType.SECOND
+      }
+    }
+    if (yearMonthFields.nonEmpty) {
+      if (dayTimeFields.nonEmpty) {
+        val literalStr = source(ctx)
+        throw new ParseException(s"Cannot mix year-month and day-time fields: $literalStr", ctx)
+      }
+      Literal(
+        calendarInterval.months,
+        YearMonthIntervalType(yearMonthFields.min, yearMonthFields.max)
+      )
+    } else {
+      Literal(
+        IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS),
+        DayTimeIntervalType(dayTimeFields.min, dayTimeFields.max))
+    }
+  }
+
+  /**
+   * Create a [[CalendarInterval]] or ANSI interval literal expression.
+   * Two syntaxes are supported:
+   * - multiple unit value pairs, for instance: interval 2 months 2 days.
+   * - from-to unit, for instance: interval '1-2' year to month.
+   */
+  override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) {
+    val calendarInterval = parseIntervalLiteral(ctx)
+    if (ctx.errorCapturingUnitToUnitInterval != null && !conf.legacyIntervalEnabled) {
+      // Check the `to` unit to distinguish year-month and day-time intervals because
+      // `CalendarInterval` doesn't have enough info. For instance, new CalendarInterval(0, 0, 0)
+      // can be derived from INTERVAL '0-0' YEAR TO MONTH as well as from
+      // INTERVAL '0 00:00:00' DAY TO SECOND.
+ val fromUnit = + ctx.errorCapturingUnitToUnitInterval.body.from.getText.toLowerCase(Locale.ROOT) + val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT) + if (toUnit == "month") { + assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0) + val start = YearMonthIntervalType.stringToField(fromUnit) + Literal(calendarInterval.months, YearMonthIntervalType(start, YearMonthIntervalType.MONTH)) + } else { + assert(calendarInterval.months == 0) + val micros = IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS) + val start = DayTimeIntervalType.stringToField(fromUnit) + val end = DayTimeIntervalType.stringToField(toUnit) + Literal(micros, DayTimeIntervalType(start, end)) + } + } else if (ctx.errorCapturingMultiUnitsInterval != null && !conf.legacyIntervalEnabled) { + val units = + ctx.errorCapturingMultiUnitsInterval.body.unit.asScala.map( + _.getText.toLowerCase(Locale.ROOT).stripSuffix("s")).toSeq + constructMultiUnitsIntervalLiteral(ctx, calendarInterval, units) + } else { + Literal(calendarInterval, CalendarIntervalType) + } + } + + /** + * Create a [[CalendarInterval]] object + */ + protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) { + if (ctx.errorCapturingMultiUnitsInterval != null) { + val innerCtx = ctx.errorCapturingMultiUnitsInterval + if (innerCtx.unitToUnitInterval != null) { + throw new ParseException("Can only have a single from-to unit in the interval literal syntax", innerCtx.unitToUnitInterval) + } + visitMultiUnitsInterval(innerCtx.multiUnitsInterval) + } else if (ctx.errorCapturingUnitToUnitInterval != null) { + val innerCtx = ctx.errorCapturingUnitToUnitInterval + if (innerCtx.error1 != null || innerCtx.error2 != null) { + val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2 + throw new ParseException("Can only have a single from-to unit in the interval literal syntax", errorCtx) + } + visitUnitToUnitInterval(innerCtx.body) + } else { + throw new ParseException("at least one time unit should be given for interval literal", ctx) + } + } + + /** + * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS. + */ + override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val units = ctx.unit.asScala + val values = ctx.intervalValue().asScala + try { + assert(units.length == values.length) + val kvs = units.indices.map { i => + val u = units(i).getText + val v = if (values(i).STRING() != null) { + val value = string(values(i).STRING()) + // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour, + // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with + // units and become valid ones, e.g. '1 day 2 hour'. + // Ideally, we only ensure the value parts don't contain any units here. 
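+          // e.g. INTERVAL '1 day 2' HOUR is rejected by the check below because its
+          // value part '1 day 2' contains letters.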
+            if (value.exists(Character.isLetter)) {
+              throw new ParseException("Can only use numbers in the interval value part for" +
+                s" multiple unit value pairs interval form, but got invalid value: $value", ctx)
+            }
+            if (values(i).MINUS() == null) {
+              value
+            } else if (value.startsWith("-")) {
+              value.replaceFirst("-", "")
+            } else {
+              s"-$value"
+            }
+          } else {
+            values(i).getText
+          }
+          UTF8String.fromString(" " + v + " " + u)
+        }
+        IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*))
+      } catch {
+        case i: IllegalArgumentException =>
+          val e = new ParseException(i.getMessage, ctx)
+          e.setStackTrace(i.getStackTrace)
+          throw e
+      }
+    }
+  }
+
+  /**
+   * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH.
+   */
+  override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = {
+    withOrigin(ctx) {
+      val value = Option(ctx.intervalValue.STRING).map(string).map { interval =>
+        if (ctx.intervalValue().MINUS() == null) {
+          interval
+        } else if (interval.startsWith("-")) {
+          interval.replaceFirst("-", "")
+        } else {
+          s"-$interval"
+        }
+      }.getOrElse {
+        throw new ParseException("The value of from-to unit must be a string", ctx.intervalValue)
+      }
+      try {
+        val from = ctx.from.getText.toLowerCase(Locale.ROOT)
+        val to = ctx.to.getText.toLowerCase(Locale.ROOT)
+        (from, to) match {
+          case ("year", "month") =>
+            IntervalUtils.fromYearMonthString(value)
+          case ("day", "hour") | ("day", "minute") | ("day", "second") | ("hour", "minute") |
+               ("hour", "second") | ("minute", "second") =>
+            IntervalUtils.fromDayTimeString(value,
+              DayTimeIntervalType.stringToField(from), DayTimeIntervalType.stringToField(to))
+          case _ =>
+            throw new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx)
+        }
+      } catch {
+        // Handle Exceptions thrown by CalendarInterval
+        case e: IllegalArgumentException =>
+          val pe = new ParseException(e.getMessage, ctx)
+          pe.setStackTrace(e.getStackTrace)
+          throw pe
+      }
+    }
+  }
+
+  /* ********************************************************************************************
+   * DataType parsing
+   * ******************************************************************************************** */
+
+  /**
+   * Resolve/create a primitive type.
+   */
+  override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+    val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT)
+    (dataType, ctx.INTEGER_VALUE().asScala.toList) match {
+      case ("boolean", Nil) => BooleanType
+      case ("tinyint" | "byte", Nil) => ByteType
+      case ("smallint" | "short", Nil) => ShortType
+      case ("int" | "integer", Nil) => IntegerType
+      case ("bigint" | "long", Nil) => LongType
+      case ("float" | "real", Nil) => FloatType
+      case ("double", Nil) => DoubleType
+      case ("date", Nil) => DateType
+      case ("timestamp", Nil) => SQLConf.get.timestampType
+      // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes.
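+      // Added note: the two timestamp aliases below are deliberately matched only when
+      // isTesting is set, so they do not resolve in production queries.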
+ case ("timestamp_ntz", Nil) if isTesting => TimestampNTZType + case ("timestamp_ltz", Nil) if isTesting => TimestampType + case ("string", Nil) => StringType + case ("character" | "char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) + case ("binary", Nil) => BinaryType + case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT + case ("decimal" | "dec" | "numeric", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType + case ("interval", Nil) => CalendarIntervalType + case (dt, params) => + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) + } + } + + override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = YearMonthIntervalType.stringToField(startStr) + if (ctx.to != null) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = YearMonthIntervalType.stringToField(endStr) + if (end <= start) { + throw new ParseException(s"Intervals FROM $startStr TO $endStr are not supported.", ctx) + } + YearMonthIntervalType(start, end) + } else { + YearMonthIntervalType(start) + } + } + + override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = DayTimeIntervalType.stringToField(startStr) + if (ctx.to != null) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = DayTimeIntervalType.stringToField(endStr) + if (end <= start) { + throw new ParseException(s"Intervals FROM $startStr TO $endStr are not supported.", ctx) + } + DayTimeIntervalType(start, end) + } else { + DayTimeIntervalType(start) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case HoodieSqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case HoodieSqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case HoodieSqlBaseParser.STRUCT => + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) + } + } + + /** + * Create top level table schema. + */ + protected def createSchema(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType).toSeq + } + + /** + * Create a top level [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + + StructField( + name = colName.getText, + dataType = typedVisit[DataType](ctx.dataType), + nullable = NULL == null, + metadata = builder.build()) + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. 
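+   *
+   * For illustration (added example): `STRUCT<id: INT, name: STRING>` becomes
+   * `StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))`.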
+ */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType).toSeq + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField( + name = identifier.getText, + dataType = typedVisit(dataType()), + nullable = NULL == null) + Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField) + } + + /** + * Create a location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional location string. + */ + protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitLocationSpec) + } + + /** + * Create a comment string. + */ + override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional comment string. + */ + protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitCommentSpec) + } + + /** + * Create a [[BucketSpec]]. + */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.ident.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. 
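+   *
+   * For illustration (added example): a clause such as `UNSET TBLPROPERTIES ('k1', 'k2')`
+   * yields `Seq("k1", "k2")`; any key that carries a value is reported via
+   * `operationNotAllowed` below.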
+   */
+  def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = {
+    val props = visitTablePropertyList(ctx)
+    val badKeys = props.filter { case (_, v) => v != null }.keys
+    if (badKeys.nonEmpty) {
+      operationNotAllowed(
+        s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx)
+    }
+    props.keys.toSeq
+  }
+
+  /**
+   * A table property key can either be a String or a collection of dot separated elements. This
+   * function extracts the property key based on whether it's a string literal or a table property
+   * identifier.
+   */
+  override def visitTablePropertyKey(key: TablePropertyKeyContext): String = {
+    if (key.STRING != null) {
+      string(key.STRING)
+    } else {
+      key.getText
+    }
+  }
+
+  /**
+   * A table property value can be a String, Integer, Boolean or Decimal. This function extracts
+   * the property value based on whether it's a string, integer, boolean or decimal literal.
+   */
+  override def visitTablePropertyValue(value: TablePropertyValueContext): String = {
+    if (value == null) {
+      null
+    } else if (value.STRING != null) {
+      string(value.STRING)
+    } else if (value.booleanValue != null) {
+      value.getText.toLowerCase(Locale.ROOT)
+    } else {
+      value.getText
+    }
+  }
+
+  /**
+   * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
+   */
+  type TableHeader = (Seq[String], Boolean, Boolean, Boolean)
+
+  /**
+   * Type to keep track of table clauses:
+   * - partition transforms
+   * - partition columns
+   * - bucketSpec
+   * - properties
+   * - options
+   * - location
+   * - comment
+   * - serde
+   *
+   * Note: Partition transforms are based on the existing table schema definition. They can be
+   * simple column names, or functions like `year(date_col)`. Partition columns are column names
+   * with data types like `i INT`, which should be appended to the existing table schema.
+   */
+  type TableClauses = (
+    Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String],
+    Map[String, String], Option[String], Option[String], Option[SerdeInfo])
+
+  /**
+   * Validate a create table statement and return the [[TableIdentifier]].
+   */
+  override def visitCreateTableHeader(
+      ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) {
+    val temporary = ctx.TEMPORARY != null
+    val ifNotExists = ctx.EXISTS != null
+    if (temporary && ifNotExists) {
+      operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx)
+    }
+    val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+    (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null)
+  }
+
+  /**
+   * Validate a replace table statement and return the [[TableIdentifier]].
+   */
+  override def visitReplaceTableHeader(
+      ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) {
+    val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq
+    (multipartIdentifier, false, false, false)
+  }
+
+  /**
+   * Parse a qualified name to a multipart name.
+   */
+  override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) {
+    ctx.identifier.asScala.map(_.getText).toSeq
+  }
+
+  /**
+   * Parse a list of transforms or columns.
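+   *
+   * For illustration (added example): `PARTITIONED BY (years(ts), region STRING)` yields one
+   * transform, `years(ts)`, and one partition column, `region STRING`.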
+   */
+  override def visitPartitionFieldList(
+      ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) {
+    val (transforms, columns) = ctx.fields.asScala.map {
+      case transform: PartitionTransformContext =>
+        (Some(visitPartitionTransform(transform)), None)
+      case field: PartitionColumnContext =>
+        (None, Some(visitColType(field.colType)))
+    }.unzip
+
+    (transforms.flatten.toSeq, columns.flatten.toSeq)
+  }
+
+  override def visitPartitionTransform(
+      ctx: PartitionTransformContext): Transform = withOrigin(ctx) {
+    def getFieldReference(
+        ctx: ApplyTransformContext,
+        arg: V2Expression): FieldReference = {
+      lazy val name: String = ctx.identifier.getText
+      arg match {
+        case ref: FieldReference =>
+          ref
+        case nonRef =>
+          throw new ParseException(
+            s"Expected a column reference for transform $name: ${nonRef.describe}", ctx)
+      }
+    }
+
+    def getSingleFieldReference(
+        ctx: ApplyTransformContext,
+        arguments: Seq[V2Expression]): FieldReference = {
+      lazy val name: String = ctx.identifier.getText
+      if (arguments.size > 1) {
+        throw new ParseException(s"Too many arguments for transform $name", ctx)
+      } else if (arguments.isEmpty) {
+        throw new ParseException(s"Not enough arguments for transform $name", ctx)
+      } else {
+        getFieldReference(ctx, arguments.head)
+      }
+    }
+
+    ctx.transform match {
+      case identityCtx: IdentityTransformContext =>
+        IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName)))
+
+      case applyCtx: ApplyTransformContext =>
+        val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq
+
+        applyCtx.identifier.getText match {
+          case "bucket" =>
+            val numBuckets: Int = arguments.head match {
+              case LiteralValue(shortValue, ShortType) =>
+                shortValue.asInstanceOf[Short].toInt
+              case LiteralValue(intValue, IntegerType) =>
+                intValue.asInstanceOf[Int]
+              case LiteralValue(longValue, LongType) =>
+                longValue.asInstanceOf[Long].toInt
+              case lit =>
+                throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx)
+            }
+
+            val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg))
+
+            BucketTransform(LiteralValue(numBuckets, IntegerType), fields)
+
+          case "years" =>
+            YearsTransform(getSingleFieldReference(applyCtx, arguments))
+
+          case "months" =>
+            MonthsTransform(getSingleFieldReference(applyCtx, arguments))
+
+          case "days" =>
+            DaysTransform(getSingleFieldReference(applyCtx, arguments))
+
+          case "hours" =>
+            HoursTransform(getSingleFieldReference(applyCtx, arguments))
+
+          case name =>
+            ApplyTransform(name, arguments)
+        }
+    }
+  }
+
+  /**
+   * Parse an argument to a transform. An argument may be a field reference (qualified name) or
+   * a value literal.
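+   *
+   * For illustration (added example): in `bucket(16, id)`, the argument `16` becomes
+   * `LiteralValue(16, IntegerType)` and `id` becomes a `FieldReference`.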
+ */ + override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new ParseException("Invalid transform argument", ctx)) + } + } + + def cleanTableProperties( + ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = { + import TableCatalog._ + val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED) + properties.filter { + case (PROP_PROVIDER, _) if !legacyOn => + throw new ParseException(s"$PROP_PROVIDER is a reserved table property, please use the USING clause to specify it.", ctx) + case (PROP_PROVIDER, _) => false + case (PROP_LOCATION, _) if !legacyOn => + throw new ParseException(s"$PROP_LOCATION is a reserved table property, please use the LOCATION clause to specify it.", ctx) + case (PROP_LOCATION, _) => false + case (PROP_OWNER, _) if !legacyOn => + throw new ParseException(s"$PROP_OWNER is a reserved table property, it will be set to the current user.", ctx) + case (PROP_OWNER, _) => false + case _ => true + } + } + + def cleanTableOptions( + ctx: ParserRuleContext, + options: Map[String, String], + location: Option[String]): (Map[String, String], Option[String]) = { + var path = location + val filtered = cleanTableProperties(ctx, options).filter { + case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty => + throw new ParseException(s"Duplicated table paths found: '${path.get}' and '$v'. LOCATION" + + s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" + + s" table path, you can only specify one of them.", ctx) + case (k, v) if k.equalsIgnoreCase("path") => + path = Some(v) + false + case _ => true + } + (filtered, path) + } + + /** + * Create a [[SerdeInfo]] for creating tables. + * + * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format) + */ + override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt)))) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + SerdeInfo(storedAs = Some(c.identifier.getText)) + case (null, storageHandler) => + operationNotAllowed("STORED BY", ctx) + case _ => + throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) + } + } + + /** + * Create a [[SerdeInfo]] used for creating tables. + * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. 
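+   *
+   * For illustration (added example; class name hypothetical):
+   * `ROW FORMAT SERDE 'com.example.MySerDe' WITH SERDEPROPERTIES ('k'='v')` yields
+   * `SerdeInfo(serde = Some("com.example.MySerDe"), serdeProperties = Map("k" -> "v"))`.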
+ */ + override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) { + import ctx._ + SerdeInfo( + serde = Some(string(name)), + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + + // TODO we need proper support for the NULL format. + val entries = + entry("field.delim", ctx.fieldsTerminatedBy) ++ + entry("serialization.format", ctx.fieldsTerminatedBy) ++ + entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... + entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ + entry("mapkey.delim", ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "line.delim" -> value + } + SerdeInfo(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... + */ + protected def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (!(rowFormatCtx == null || createFileFormatCtx == null)) { + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } + } + } + + protected def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + + override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = { + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED 
AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + if (ctx.skewSpec.size > 0) { + operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) + } + + val (partTransforms, partCols) = + Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil)) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val cleanedProperties = cleanTableProperties(ctx, properties) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val location = visitLocationSpecList(ctx.locationSpec()) + val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location) + val comment = visitCommentSpecList(ctx.commentSpec()) + val serdeInfo = + getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, + serdeInfo) + } + + protected def getSerdeInfo( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + ctx: ParserRuleContext): Option[SerdeInfo] = { + validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx) + val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat) + val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat) + (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r)) + } + + private def partitionExpressions( + partTransforms: Seq[Transform], + partCols: Seq[StructField], + ctx: ParserRuleContext): Seq[Transform] = { + if (partTransforms.nonEmpty) { + if (partCols.nonEmpty) { + val references = partTransforms.map(_.describe()).mkString(", ") + val columns = partCols + .map(field => s"${field.name} ${field.dataType.simpleString}") + .mkString(", ") + operationNotAllowed( + s"""PARTITION BY: Cannot mix partition expressions and partition columns: + |Expressions: $references + |Columns: $columns""".stripMargin, ctx) + + } + partTransforms + } else { + // columns were added to create the schema. convert to column references + partCols.map { column => + IdentityTransform(FieldReference(Seq(column.name))) + } + } + } + + /** + * Create a table, returning a [[CreateTable]] or [[CreateTableAsSelect]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * [USING table_provider] + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [PARTITIONED BY (partition_fields)] + * [OPTIONS table_property_list] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... 
+   * }}}
+   */
+  override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) {
+    val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+
+    val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil)
+    val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText)
+    val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) =
+      visitCreateTableClauses(ctx.createTableClauses())
+
+    if (provider.isDefined && serdeInfo.isDefined) {
+      operationNotAllowed(s"CREATE TABLE ... USING ... ${serdeInfo.get.describe}", ctx)
+    }
+
+    if (temp) {
+      val asSelect = if (ctx.query == null) "" else " AS ..."
+      operationNotAllowed(
+        s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx)
+    }
+
+    // Partition transforms for BucketSpec were moved inside the parser.
+    // https://issues.apache.org/jira/browse/SPARK-37923
+    val partitioning =
+      partitionExpressions(partTransforms, partCols, ctx) ++ bucketSpec.map(_.asTransform)
+    val tableSpec = TableSpec(properties, provider, options, location, comment,
+      serdeInfo, external)
+
+    Option(ctx.query).map(plan) match {
+      case Some(_) if columns.nonEmpty =>
+        operationNotAllowed(
+          "Schema may not be specified in a Create Table As Select (CTAS) statement",
+          ctx)
+
+      case Some(_) if partCols.nonEmpty =>
+        // non-reference partition columns are not allowed because schema can't be specified
+        operationNotAllowed(
+          "Partition column types may not be specified in Create Table As Select (CTAS)",
+          ctx)
+
+      // CreateTable / CreateTableAsSelect was migrated to v2 in Spark 3.3.0
+      // https://issues.apache.org/jira/browse/SPARK-36850
+      case Some(query) =>
+        CreateTableAsSelect(
+          UnresolvedIdentifier(table),
+          partitioning, query, tableSpec, Map.empty, ifNotExists)
+
+      case _ =>
+        // Note: table schema includes both the table columns list and the partition columns
+        // with data type.
+        val schema = StructType(columns ++ partCols)
+        CreateTable(
+          UnresolvedIdentifier(table),
+          schema, partitioning, tableSpec, ignoreIfExists = ifNotExists)
+    }
+  }
+
+  /**
+   * Parse new column info from ADD COLUMN into a QualifiedColType.
+   */
+  override def visitQualifiedColTypeWithPosition(
+      ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) {
+    val name = typedVisit[Seq[String]](ctx.name)
+    QualifiedColType(
+      path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None,
+      colName = name.last,
+      dataType = typedVisit[DataType](ctx.dataType),
+      nullable = ctx.NULL == null,
+      comment = Option(ctx.commentSpec()).map(visitCommentSpec),
+      position = Option(ctx.colPosition).map(pos =>
+        UnresolvedFieldPosition(typedVisit[ColumnPosition](pos))),
+      default = None)
+  }
+
+  /**
+   * Create an index, returning a [[CreateIndex]] logical plan.
+   * For example:
+   * {{{
+   *   CREATE INDEX index_name ON [TABLE] table_name [USING index_type] (column_index_property_list)
+   *     [OPTIONS indexPropertyList]
+   *   column_index_property_list: column_name [OPTIONS(indexPropertyList)] [ , . . . ]
+   *   indexPropertyList: index_property_name [= index_property_value] [ , . . .
] + * }}} + */ + override def visitCreateIndex(ctx: CreateIndexContext): LogicalPlan = withOrigin(ctx) { + val (indexName, indexType) = if (ctx.identifier.size() == 1) { + (ctx.identifier(0).getText, "") + } else { + (ctx.identifier(0).getText, ctx.identifier(1).getText) + } + + val columns = ctx.columns.multipartIdentifierProperty.asScala + .map(_.multipartIdentifier).map(typedVisit[Seq[String]]).toSeq + val columnsProperties = ctx.columns.multipartIdentifierProperty.asScala + .map(x => (Option(x.options).map(visitPropertyKeyValues).getOrElse(Map.empty))).toSeq + val options = Option(ctx.indexOptions).map(visitPropertyKeyValues).getOrElse(Map.empty) + + CreateIndex( + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), + indexName, + indexType, + ctx.EXISTS != null, + columns.map(UnresolvedFieldName).zip(columnsProperties), + options) + } + + /** + * Drop an index, returning a [[DropIndex]] logical plan. + * For example: + * {{{ + * DROP INDEX [IF EXISTS] index_name ON [TABLE] table_name + * }}} + */ + override def visitDropIndex(ctx: DropIndexContext): LogicalPlan = withOrigin(ctx) { + val indexName = ctx.identifier.getText + DropIndex( + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), + indexName, + ctx.EXISTS != null) + } + + /** + * Show indexes, returning a [[ShowIndexes]] logical plan. + * For example: + * {{{ + * SHOW INDEXES (FROM | IN) [TABLE] table_name + * }}} + */ + override def visitShowIndexes(ctx: ShowIndexesContext): LogicalPlan = withOrigin(ctx) { + ShowIndexes(UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier()))) + } + + /** + * Refresh index, returning a [[RefreshIndex]] logical plan + * For example: + * {{{ + * REFRESH INDEX index_name ON [TABLE] table_name + * }}} + */ + override def visitRefreshIndex(ctx: RefreshIndexContext): LogicalPlan = withOrigin(ctx) { + RefreshIndex(UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier())), ctx.identifier.getText) + } + + /** + * Convert a property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitPropertyList(ctx: PropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.property.asScala.map { property => + val key = visitPropertyKey(property.key) + val value = visitPropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[PropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: PropertyListContext): Map[String, String] = { + val props = visitPropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[PropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: PropertyListContext): Seq[String] = { + val props = visitPropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A property key can either be String or a collection of dot separated elements. 
This
+ * function extracts the property key based on whether it's a string literal or a property
+ * identifier.
+ */
+  override def visitPropertyKey(key: PropertyKeyContext): String = {
+    if (key.STRING != null) {
+      string(key.STRING)
+    } else {
+      key.getText
+    }
+  }
+
+  /**
+   * A property value can be a String, Integer, Boolean or Decimal. This function extracts
+   * the property value based on whether it's a string, integer, boolean or decimal literal.
+   */
+  override def visitPropertyValue(value: PropertyValueContext): String = {
+    if (value == null) {
+      null
+    } else if (value.STRING != null) {
+      string(value.STRING)
+    } else if (value.booleanValue != null) {
+      value.getText.toLowerCase(Locale.ROOT)
+    } else {
+      value.getText
+    }
+  }
+}
+
+/**
+ * A container for holding named common table expressions (CTEs) and a query plan.
+ * This operator will be removed during analysis and the relations will be substituted into
+ * the child.
+ *
+ * @param child The final query of this CTE.
+ * @param cteRelations A sequence of pairs (alias, CTE definition) that this CTE defined.
+ *                     Each CTE can see the base tables and the previously defined CTEs only.
+ */
+case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  override def simpleString(maxFields: Int): String = {
+    val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
+    s"CTE $cteAliases"
+  }
+
+  override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+  def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = this
+}
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlParser.scala
new file mode 100644
index 0000000000000..bbde7bea5538b
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_5ExtendedSqlParser.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.parser + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext} +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.internal.VariableSubstitution +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SparkSession} + +import java.util.Locale + +class HoodieSpark3_5ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) + extends HoodieExtendedParserInterface with Logging { + + private lazy val conf = session.sqlContext.conf + private lazy val builder = new HoodieSpark3_5ExtendedSqlAstBuilder(conf, delegate) + private val substitutor = new VariableSubstitution + + override def parsePlan(sqlText: String): LogicalPlan = { + val substitutionSql = substitutor.substitute(sqlText) + if (isHoodieCommand(substitutionSql)) { + parse(substitutionSql) { parser => + builder.visit(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => delegate.parsePlan(sqlText) + } + } + } else { + delegate.parsePlan(substitutionSql) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = delegate.parseQuery(sqlText) + + override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText) + + protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = { + logDebug(s"Parsing command: $command") + + val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new HoodieSqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + // parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced + parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled + parser.SQL_standard_keyword_behavior = conf.ansiEnabled + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
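+          // SLL is ANTLR's faster but incomplete prediction mode; the ParseCancellationException
+          // above means it gave up, so the token stream has been rewound and the input is
+          // re-parsed with the slower, complete LL mode.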
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + private def isHoodieCommand(sqlText: String): Boolean = { + val normalized = sqlText.toLowerCase(Locale.ROOT).trim().replaceAll("\\s+", " ") + normalized.contains("system_time as of") || + normalized.contains("timestamp as of") || + normalized.contains("system_version as of") || + normalized.contains("version as of") || + normalized.contains("create index") || + normalized.contains("drop index") || + normalized.contains("show indexes") || + normalized.contains("refresh index") + } +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`. + */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + // scalastyle:off + override def LA(i: Int): Int = { + // scalastyle:on + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`. + */ +case object PostProcessor extends HoodieSqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
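+   * For example, a non-reserved keyword used as a column name is re-tokenized as a plain
+   * IDENTIFIER by [[replaceTokenByIdentifier]].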
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + HoodieSqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java new file mode 100644 index 0000000000000..d4b0b0e764ed8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/internal/HoodieBulkInsertInternalWriterTestBase.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.internal; + +import org.apache.hudi.DataSourceWriteOptions; +import org.apache.hudi.client.WriteStatus; +import org.apache.hudi.common.model.HoodieRecord; +import org.apache.hudi.common.model.HoodieRecord.HoodieMetadataField; +import org.apache.hudi.common.model.HoodieWriteStat; +import org.apache.hudi.common.table.HoodieTableConfig; +import org.apache.hudi.common.testutils.HoodieTestDataGenerator; +import org.apache.hudi.common.util.Option; +import org.apache.hudi.config.HoodieWriteConfig; +import org.apache.hudi.testutils.HoodieSparkClientTestHarness; +import org.apache.hudi.testutils.SparkDatasetTestUtils; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Base class for TestHoodieBulkInsertDataInternalWriter. 
+ */
+public class HoodieBulkInsertInternalWriterTestBase extends HoodieSparkClientTestHarness {
+
+  protected static final Random RANDOM = new Random();
+
+  @BeforeEach
+  public void setUp() throws Exception {
+    initSparkContexts();
+    initPath();
+    initFileSystem();
+    initTestDataGenerator();
+    initMetaClient();
+    initTimelineService();
+  }
+
+  @AfterEach
+  public void tearDown() throws Exception {
+    cleanupResources();
+  }
+
+  protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields) {
+    return getWriteConfig(populateMetaFields, DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().defaultValue());
+  }
+
+  protected HoodieWriteConfig getWriteConfig(boolean populateMetaFields, String hiveStylePartitioningValue) {
+    Properties properties = new Properties();
+    if (!populateMetaFields) {
+      properties.setProperty(DataSourceWriteOptions.RECORDKEY_FIELD().key(), SparkDatasetTestUtils.RECORD_KEY_FIELD_NAME);
+      properties.setProperty(DataSourceWriteOptions.PARTITIONPATH_FIELD().key(), SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME);
+      properties.setProperty(HoodieTableConfig.POPULATE_META_FIELDS.key(), "false");
+    }
+    properties.setProperty(DataSourceWriteOptions.HIVE_STYLE_PARTITIONING().key(), hiveStylePartitioningValue);
+    return SparkDatasetTestUtils.getConfigBuilder(basePath, timelineServicePort).withProperties(properties).build();
+  }
+
+  protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size,
+                                     Option<List<String>> fileAbsPaths, Option<List<String>> fileNames) {
+    assertWriteStatuses(writeStatuses, batches, size, false, fileAbsPaths, fileNames, false);
+  }
+
+  protected void assertWriteStatuses(List<WriteStatus> writeStatuses, int batches, int size, boolean areRecordsSorted,
+                                     Option<List<String>> fileAbsPaths, Option<List<String>> fileNames, boolean isHiveStylePartitioning) {
+    if (areRecordsSorted) {
+      assertEquals(batches, writeStatuses.size());
+    } else {
+      assertEquals(Math.min(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS.length, batches), writeStatuses.size());
+    }
+
+    Map<String, Long> sizeMap = new HashMap<>();
+    if (!areRecordsSorted) {
+      // `size` records are written per batch, and batches cycle through three partition paths,
+      // so every third batch lands in the same WriteStatus. Populate the expected total
+      // per write status up front.
+      for (int i = 0; i < batches; i++) {
+        String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[i % 3];
+        if (!sizeMap.containsKey(partitionPath)) {
+          sizeMap.put(partitionPath, 0L);
+        }
+        sizeMap.put(partitionPath, sizeMap.get(partitionPath) + size);
+      }
+    }
+
+    int counter = 0;
+    for (WriteStatus writeStatus : writeStatuses) {
+      // verify write status
+      String actualPartitionPathFormat = isHiveStylePartitioning ? SparkDatasetTestUtils.PARTITION_PATH_FIELD_NAME + "=%s" : "%s";
+      assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStatus.getPartitionPath());
+      if (areRecordsSorted) {
+        assertEquals(writeStatus.getTotalRecords(), size);
+      } else {
+        assertEquals(writeStatus.getTotalRecords(), sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]));
+      }
+      assertNull(writeStatus.getGlobalError());
+      assertEquals(writeStatus.getTotalErrorRecords(), 0);
+      assertFalse(writeStatus.hasErrors());
+      assertNotNull(writeStatus.getFileId());
+      String fileId = writeStatus.getFileId();
+      if (fileAbsPaths.isPresent()) {
+        fileAbsPaths.get().add(basePath + "/" + writeStatus.getStat().getPath());
+      }
+      if (fileNames.isPresent()) {
+        fileNames.get().add(writeStatus.getStat().getPath()
+            .substring(writeStatus.getStat().getPath().lastIndexOf('/') + 1));
+      }
+      HoodieWriteStat writeStat = writeStatus.getStat();
+      if (areRecordsSorted) {
+        assertEquals(size, writeStat.getNumInserts());
+        assertEquals(size, writeStat.getNumWrites());
+      } else {
+        assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumInserts());
+        assertEquals(sizeMap.get(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter % 3]), writeStat.getNumWrites());
+      }
+      assertEquals(fileId, writeStat.getFileId());
+      assertEquals(String.format(actualPartitionPathFormat, HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[counter++ % 3]), writeStat.getPartitionPath());
+      assertEquals(0, writeStat.getNumDeletes());
+      assertEquals(0, writeStat.getNumUpdateWrites());
+      assertEquals(0, writeStat.getTotalWriteErrors());
+    }
+  }
+
+  protected void assertOutput(Dataset<Row> expectedRows, Dataset<Row> actualRows, String instantTime, Option<List<String>> fileNames,
+                              boolean populateMetaColumns) {
+    if (populateMetaColumns) {
+      // verify 3 meta fields that are filled in within create handle
+      actualRows.collectAsList().forEach(entry -> {
+        assertEquals(entry.get(HoodieMetadataField.COMMIT_TIME_METADATA_FIELD.ordinal()).toString(), instantTime);
+        assertFalse(entry.isNullAt(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal()));
+        if (fileNames.isPresent()) {
+          assertTrue(fileNames.get().contains(entry.get(HoodieMetadataField.FILENAME_METADATA_FIELD.ordinal())));
+        }
+        assertFalse(entry.isNullAt(HoodieMetadataField.COMMIT_SEQNO_METADATA_FIELD.ordinal()));
+      });
+
+      // after trimming the meta fields, the rest of the fields should match
+      Dataset<Row> trimmedExpected = expectedRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD);
+      Dataset<Row> trimmedActual = actualRows.drop(HoodieRecord.COMMIT_SEQNO_METADATA_FIELD, HoodieRecord.COMMIT_TIME_METADATA_FIELD, HoodieRecord.FILENAME_METADATA_FIELD);
+      assertEquals(0, trimmedActual.except(trimmedExpected).count());
+    } else { // operation = BULK_INSERT_APPEND_ONLY
+      // all meta columns are untouched
+      assertEquals(0, expectedRows.except(actualRows).count());
+    }
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java
new file mode 100644
index 0000000000000..96b06937504f1
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieBulkInsertDataInternalWriter.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.spark3.internal;
+
+import org.apache.hudi.common.testutils.HoodieTestDataGenerator;
+import org.apache.hudi.common.util.Option;
+import org.apache.hudi.config.HoodieWriteConfig;
+import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase;
+import org.apache.hudi.table.HoodieSparkTable;
+import org.apache.hudi.table.HoodieTable;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Stream;
+
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.getInternalRowWithError;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * Unit tests {@link HoodieBulkInsertDataInternalWriter}.
+ */
+public class TestHoodieBulkInsertDataInternalWriter extends
+    HoodieBulkInsertInternalWriterTestBase {
+
+  private static Stream<Arguments> configParams() {
+    Object[][] data = new Object[][] {
+        {true, true},
+        {true, false},
+        {false, true},
+        {false, false}
+    };
+    return Stream.of(data).map(Arguments::of);
+  }
+
+  private static Stream<Arguments> bulkInsertTypeParams() {
+    Object[][] data = new Object[][] {
+        {true},
+        {false}
+    };
+    return Stream.of(data).map(Arguments::of);
+  }
+
+  @ParameterizedTest
+  @MethodSource("configParams")
+  public void testDataInternalWriter(boolean sorted, boolean populateMetaFields) throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    // execute N rounds
+    for (int i = 0; i < 2; i++) {
+      String instantTime = "00" + i;
+      // init writer
+      HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000),
+          RANDOM.nextLong(), STRUCT_TYPE, populateMetaFields, sorted);
+
+      int size = 10 + RANDOM.nextInt(1000);
+      // write N rows to partition1, N rows to partition2 and N rows to partition3 ... Each batch should create a new RowCreateHandle and a new file
+      int batches = 3;
+      Dataset<Row> totalInputRows = null;
+
+      for (int j = 0; j < batches; j++) {
+        String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+        Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+        writeRows(inputRows, writer);
+        if (totalInputRows == null) {
+          totalInputRows = inputRows;
+        } else {
+          totalInputRows = totalInputRows.union(inputRows);
+        }
+      }
+
+      HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+      Option<List<String>> fileAbsPaths = Option.of(new ArrayList<>());
+      Option<List<String>> fileNames = Option.of(new ArrayList<>());
+
+      // verify write statuses
+      assertWriteStatuses(commitMetadata.getWriteStatuses(), batches, size, sorted, fileAbsPaths, fileNames, false);
+
+      // verify rows
+      Dataset<Row> result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0]));
+      assertOutput(totalInputRows, result, instantTime, fileNames, populateMetaFields);
+    }
+  }
+
+  /**
+   * Issues a few corrupted or wrongly schematized InternalRows after a batch of valid InternalRows so that a
+   * global error is thrown: write batch 1 of valid records, then batch 2 of invalid records which is expected
+   * to raise the global error. Verify the global error is set appropriately and that only the first batch of
+   * records is written to disk.
+   */
+  @Test
+  public void testGlobalFailure() throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(true);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[0];
+
+    String instantTime = "001";
+    HoodieBulkInsertDataInternalWriter writer = new HoodieBulkInsertDataInternalWriter(table, cfg, instantTime, RANDOM.nextInt(100000),
+        RANDOM.nextLong(), STRUCT_TYPE, true, false);
+
+    int size = 10 + RANDOM.nextInt(100);
+    int totalFailures = 5;
+    // Generate first batch of valid rows
+    Dataset<Row> inputRows = getRandomRows(sqlContext, size / 2, partitionPath, false);
+    List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER);
+
+    // generate some failure rows
+    for (int i = 0; i < totalFailures; i++) {
+      internalRows.add(getInternalRowWithError(partitionPath));
+    }
+
+    // generate 2nd batch of valid rows
+    Dataset<Row> inputRows2 = getRandomRows(sqlContext, size / 2, partitionPath, false);
+    internalRows.addAll(toInternalRows(inputRows2, ENCODER));
+
+    // issue writes
+    try {
+      for (InternalRow internalRow : internalRows) {
+        writer.write(internalRow);
+      }
+      fail("Should have failed");
+    } catch (Throwable e) {
+      // expected
+    }
+
+    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+
+    Option<List<String>> fileAbsPaths = Option.of(new ArrayList<>());
+    Option<List<String>> fileNames = Option.of(new ArrayList<>());
+    // verify write statuses
+    assertWriteStatuses(commitMetadata.getWriteStatuses(), 1, size / 2, fileAbsPaths, fileNames);
+
+    // verify rows
+    Dataset<Row> result = sqlContext.read().parquet(fileAbsPaths.get().toArray(new String[0]));
+    assertOutput(inputRows, result, instantTime, fileNames, true);
+  }
+
+  private void writeRows(Dataset<Row> inputRows, HoodieBulkInsertDataInternalWriter writer)
+      throws Exception {
+    List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER);
+    // issue writes
+    for (InternalRow internalRow : internalRows) {
+      writer.write(internalRow);
+    }
+  }
+}
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java
diff --git a/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java
new file mode 100644
index 0000000000000..176b67bbe98f4
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestHoodieDataSourceInternalBatchWrite.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hudi.spark3.internal;
+
+import org.apache.hudi.DataSourceWriteOptions;
+import org.apache.hudi.common.model.HoodieCommitMetadata;
+import org.apache.hudi.common.testutils.HoodieTestDataGenerator;
+import org.apache.hudi.common.util.Option;
+import org.apache.hudi.config.HoodieWriteConfig;
+import org.apache.hudi.internal.HoodieBulkInsertInternalWriterTestBase;
+import org.apache.hudi.table.HoodieSparkTable;
+import org.apache.hudi.table.HoodieTable;
+import org.apache.hudi.testutils.HoodieClientTestUtils;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.write.DataWriter;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Stream;
+
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.ENCODER;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.STRUCT_TYPE;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.getRandomRows;
+import static org.apache.hudi.testutils.SparkDatasetTestUtils.toInternalRows;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Unit tests {@link HoodieDataSourceInternalBatchWrite}.
+ */
+public class TestHoodieDataSourceInternalBatchWrite extends
+    HoodieBulkInsertInternalWriterTestBase {
+
+  private static Stream<Arguments> bulkInsertTypeParams() {
+    Object[][] data = new Object[][] {
+        {true},
+        {false}
+    };
+    return Stream.of(data).map(Arguments::of);
+  }
+
+  @ParameterizedTest
+  @MethodSource("bulkInsertTypeParams")
+  public void testDataSourceWriter(boolean populateMetaFields) throws Exception {
+    testDataSourceWriterInternal(Collections.EMPTY_MAP, Collections.EMPTY_MAP, populateMetaFields);
+  }
+
+  private void testDataSourceWriterInternal(Map<String, String> extraMetadata, Map<String, String> expectedExtraMetadata, boolean populateMetaFields) throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    String instantTime = "001";
+    // init writer
+    HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
+        new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, extraMetadata, populateMetaFields, false);
+    DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong());
+
+    String[] partitionPaths = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS;
+    List<String> partitionPathsAbs = new ArrayList<>();
+    for (String partitionPath : partitionPaths) {
+      partitionPathsAbs.add(basePath + "/" + partitionPath + "/*");
+    }
+
+    int size = 10 + RANDOM.nextInt(1000);
+    int batches = 5;
+    Dataset<Row> totalInputRows = null;
+
+    for (int j = 0; j < batches; j++) {
+      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+      writeRows(inputRows, writer);
+      if (totalInputRows == null) {
+        totalInputRows = inputRows;
+      } else {
+        totalInputRows = totalInputRows.union(inputRows);
+      }
+    }
+
+    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+    List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
+    commitMessages.add(commitMetadata);
+    dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
+
+    metaClient.reloadActiveTimeline();
+    Dataset<Row> result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
+    // verify output
+    assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
+    assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
+
+    // verify extra metadata
+    Option<HoodieCommitMetadata> commitMetadataOption = HoodieClientTestUtils.getCommitMetadataForLatestInstant(metaClient);
+    assertTrue(commitMetadataOption.isPresent());
+    Map<String, String> actualExtraMetadata = new HashMap<>();
+    commitMetadataOption.get().getExtraMetadata().entrySet().stream().filter(entry ->
+        !entry.getKey().equals(HoodieCommitMetadata.SCHEMA_KEY)).forEach(entry -> actualExtraMetadata.put(entry.getKey(), entry.getValue()));
+    assertEquals(expectedExtraMetadata, actualExtraMetadata);
+  }
+
+  @Test
+  public void testDataSourceWriterExtraCommitMetadata() throws Exception {
+    String commitExtraMetaPrefix = "commit_extra_meta_";
+    Map<String, String> extraMeta = new HashMap<>();
+    extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix);
+    extraMeta.put(commitExtraMetaPrefix + "a", "valA");
+    extraMeta.put(commitExtraMetaPrefix + "b", "valB");
+    extraMeta.put("commit_extra_c", "valC"); // should not be part of commit extra metadata
+
+    Map<String, String> expectedMetadata = new HashMap<>();
+    expectedMetadata.putAll(extraMeta);
+    expectedMetadata.remove(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key());
+    expectedMetadata.remove("commit_extra_c");
+
+    testDataSourceWriterInternal(extraMeta, expectedMetadata, true);
+  }
+
+  @Test
+  public void testDataSourceWriterEmptyExtraCommitMetadata() throws Exception {
+    String commitExtraMetaPrefix = "commit_extra_meta_";
+    Map<String, String> extraMeta = new HashMap<>();
+    extraMeta.put(DataSourceWriteOptions.COMMIT_METADATA_KEYPREFIX().key(), commitExtraMetaPrefix);
+    extraMeta.put("keyA", "valA");
+    extraMeta.put("keyB", "valB");
+    extraMeta.put("commit_extra_c", "valC");
+    // none of the keys has the commit metadata key prefix.
+    testDataSourceWriterInternal(extraMeta, Collections.EMPTY_MAP, true);
+  }
+
+  @ParameterizedTest
+  @MethodSource("bulkInsertTypeParams")
+  public void testMultipleDataSourceWrites(boolean populateMetaFields) throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    int partitionCounter = 0;
+
+    // execute N rounds
+    for (int i = 0; i < 2; i++) {
+      String instantTime = "00" + i;
+      // init writer
+      HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
+          new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.EMPTY_MAP, populateMetaFields, false);
+      List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
+      Dataset<Row> totalInputRows = null;
+      DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong());
+
+      int size = 10 + RANDOM.nextInt(1000);
+      int batches = 3; // one batch per partition
+
+      for (int j = 0; j < batches; j++) {
+        String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+        Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+        writeRows(inputRows, writer);
+        if (totalInputRows == null) {
+          totalInputRows = inputRows;
+        } else {
+          totalInputRows = totalInputRows.union(inputRows);
+        }
+      }
+
+      HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+      commitMessages.add(commitMetadata);
+      dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
+      metaClient.reloadActiveTimeline();
+
+      Dataset<Row> result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime, populateMetaFields);
+
+      // verify output
+      assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
+      assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
+    }
+  }
+
+  // Large writes are not required to be executed with regular CI jobs; they take a lot of running time.
+  @Disabled
+  @ParameterizedTest
+  @MethodSource("bulkInsertTypeParams")
+  public void testLargeWrites(boolean populateMetaFields) throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    int partitionCounter = 0;
+
+    // execute N rounds
+    for (int i = 0; i < 3; i++) {
+      String instantTime = "00" + i;
+      // init writer
+      HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
+          new HoodieDataSourceInternalBatchWrite(instantTime, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.EMPTY_MAP, populateMetaFields, false);
+      List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
+      Dataset<Row> totalInputRows = null;
+      DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(partitionCounter++, RANDOM.nextLong());
+
+      int size = 10000 + RANDOM.nextInt(10000);
+      int batches = 3; // one batch per partition
+
+      for (int j = 0; j < batches; j++) {
+        String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+        Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+        writeRows(inputRows, writer);
+        if (totalInputRows == null) {
+          totalInputRows = inputRows;
+        } else {
+          totalInputRows = totalInputRows.union(inputRows);
+        }
+      }
+
+      HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+      commitMessages.add(commitMetadata);
+      dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
+      metaClient.reloadActiveTimeline();
+
+      Dataset<Row> result = HoodieClientTestUtils.readCommit(basePath, sqlContext, metaClient.getCommitTimeline(), instantTime,
+          populateMetaFields);
+
+      // verify output
+      assertOutput(totalInputRows, result, instantTime, Option.empty(), populateMetaFields);
+      assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
+    }
+  }
+
+  /**
+   * Tests that DataSourceWriter.abort() discards the records written by an aborted batch:
+   * commit batch1,
+   * abort batch2,
+   * verify that only records from batch1 are available to read.
+   */
+  @ParameterizedTest
+  @MethodSource("bulkInsertTypeParams")
+  public void testAbort(boolean populateMetaFields) throws Exception {
+    // init config and table
+    HoodieWriteConfig cfg = getWriteConfig(populateMetaFields);
+    HoodieTable table = HoodieSparkTable.create(cfg, context, metaClient);
+    String instantTime0 = "00" + 0;
+    // init writer
+    HoodieDataSourceInternalBatchWrite dataSourceInternalBatchWrite =
+        new HoodieDataSourceInternalBatchWrite(instantTime0, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.EMPTY_MAP, populateMetaFields, false);
+    DataWriter<InternalRow> writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(0, RANDOM.nextLong());
+
+    List<String> partitionPaths = Arrays.asList(HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS);
+    List<String> partitionPathsAbs = new ArrayList<>();
+    for (String partitionPath : partitionPaths) {
+      partitionPathsAbs.add(basePath + "/" + partitionPath + "/*");
+    }
+
+    int size = 10 + RANDOM.nextInt(100);
+    int batches = 1;
+    Dataset<Row> totalInputRows = null;
+
+    for (int j = 0; j < batches; j++) {
+      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+      writeRows(inputRows, writer);
+      if (totalInputRows == null) {
+        totalInputRows = inputRows;
+      } else {
+        totalInputRows = totalInputRows.union(inputRows);
+      }
+    }
+
+    HoodieWriterCommitMessage commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+    List<HoodieWriterCommitMessage> commitMessages = new ArrayList<>();
+    commitMessages.add(commitMetadata);
+    // commit 1st batch
+    dataSourceInternalBatchWrite.commit(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
+    metaClient.reloadActiveTimeline();
+    Dataset<Row> result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
+    // verify rows
+    assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields);
+    assertWriteStatuses(commitMessages.get(0).getWriteStatuses(), batches, size, Option.empty(), Option.empty());
+
+    // 2nd batch: abort in the end
+    String instantTime1 = "00" + 1;
+    dataSourceInternalBatchWrite =
+        new HoodieDataSourceInternalBatchWrite(instantTime1, cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf, Collections.EMPTY_MAP, populateMetaFields, false);
+    writer = dataSourceInternalBatchWrite.createBatchWriterFactory(null).createWriter(1, RANDOM.nextLong());
+
+    for (int j = 0; j < batches; j++) {
+      String partitionPath = HoodieTestDataGenerator.DEFAULT_PARTITION_PATHS[j % 3];
+      Dataset<Row> inputRows = getRandomRows(sqlContext, size, partitionPath, false);
+      writeRows(inputRows, writer);
+    }
+
+    commitMetadata = (HoodieWriterCommitMessage) writer.commit();
+    commitMessages = new ArrayList<>();
+    commitMessages.add(commitMetadata);
+    // abort 2nd batch
+    dataSourceInternalBatchWrite.abort(commitMessages.toArray(new HoodieWriterCommitMessage[0]));
+    metaClient.reloadActiveTimeline();
+    result = HoodieClientTestUtils.read(jsc, basePath, sqlContext, metaClient.getFs(), partitionPathsAbs.toArray(new String[0]));
+    // verify rows: only rows from the first batch should be present
+    assertOutput(totalInputRows, result, instantTime0, Option.empty(), populateMetaFields);
+  }
+
+  private void writeRows(Dataset<Row> inputRows, DataWriter<InternalRow> writer) throws Exception {
+    List<InternalRow> internalRows = toInternalRows(inputRows, ENCODER);
+    // issue writes
+    for (InternalRow internalRow : internalRows) {
+      writer.write(internalRow);
+    }
+  }
+}
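The batch-write tests above all follow Spark's two-phase DataSource V2 write protocol: each task-side DataWriter produces a commit message, and the driver-side HoodieDataSourceInternalBatchWrite either commits the collected messages as a Hudi instant or aborts them. A condensed sketch under the same assumptions as the tests (cfg, STRUCT_TYPE, sqlContext, and hadoopConf come from the test base; a single writer stands in for many parallel tasks):

    // Sketch only: the commit/abort protocol exercised by testDataSourceWriter and testAbort.
    private void twoPhaseWriteSketch(HoodieWriteConfig cfg) throws Exception {
      HoodieDataSourceInternalBatchWrite batchWrite = new HoodieDataSourceInternalBatchWrite(
          "001", cfg, STRUCT_TYPE, sqlContext.sparkSession(), hadoopConf,
          Collections.emptyMap(), // extra commit metadata
          true,                   // populateMetaFields
          false);
      // Task side: every writer emits one commit message.
      DataWriter<InternalRow> writer = batchWrite.createBatchWriterFactory(null).createWriter(0, 0L);
      // ... writer.write(internalRow) for each row ...
      HoodieWriterCommitMessage msg = (HoodieWriterCommitMessage) writer.commit();
      // Driver side: the messages either finalize the instant on the timeline ...
      batchWrite.commit(new HoodieWriterCommitMessage[] {msg});
      // ... or discard it, leaving earlier committed instants readable (what testAbort checks):
      // batchWrite.abort(new HoodieWriterCommitMessage[] {msg});
    }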
diff --git a/hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java
similarity index 90%
rename from hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java
rename to hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java
index 075e4242cb006..5a08e54f5e171 100644
--- a/hudi-spark-datasource/hudi-spark3-common/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java
+++ b/hudi-spark-datasource/hudi-spark3.5.x/src/test/java/org/apache/hudi/spark3/internal/TestReflectUtil.java
@@ -23,14 +23,10 @@
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.plans.logical.InsertIntoStatement;
+
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.Collections;
-
-import static scala.collection.JavaConverters.asScalaBuffer;
-
-
 /**
  * Unit tests {@link ReflectUtil}.
  */
@@ -46,10 +42,11 @@ public void testDataSourceWriterExtraCommitMetadata() throws Exception {
     InsertIntoStatement newStatment = ReflectUtil.createInsertInto(
         statement.table(),
         statement.partitionSpec(),
-        asScalaBuffer(Collections.emptyList()).toSeq(),
+        scala.collection.immutable.List.empty(),
        statement.query(),
        statement.overwrite(),
-        statement.ifPartitionNotExists());
+        statement.ifPartitionNotExists(),
+        statement.byName());
 
     Assertions.assertTrue(
         ((UnresolvedRelation) newStatment.table()).multipartIdentifier().contains("test_reflect_util"));
diff --git a/packaging/bundle-validation/base/build_flink1180hive313spark350.sh b/packaging/bundle-validation/base/build_flink1180hive313spark350.sh
new file mode 100755
index 0000000000000..dca3acdc5bc57
--- /dev/null
+++ b/packaging/bundle-validation/base/build_flink1180hive313spark350.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+docker build \
+  --build-arg HIVE_VERSION=3.1.3 \
+  --build-arg FLINK_VERSION=1.18.0 \
+  --build-arg SPARK_VERSION=3.5.0 \
+  --build-arg SPARK_HADOOP_VERSION=3 \
+  --build-arg HADOOP_VERSION=3.3.5 \
+  -t hudi-ci-bundle-validation-base:flink1180hive313spark350 .
+docker image tag hudi-ci-bundle-validation-base:flink1180hive313spark350 apachehudi/hudi-ci-bundle-validation-base:flink1180hive313spark350
diff --git a/packaging/bundle-validation/ci_run.sh b/packaging/bundle-validation/ci_run.sh
index 505ee9c7c2d48..59fc5d9df3972 100755
--- a/packaging/bundle-validation/ci_run.sh
+++ b/packaging/bundle-validation/ci_run.sh
@@ -104,6 +104,16 @@ elif [[ ${SPARK_RUNTIME} == 'spark3.4.0' ]]; then
   CONFLUENT_VERSION=5.5.12
   KAFKA_CONNECT_HDFS_VERSION=10.1.13
   IMAGE_TAG=flink1170hive313spark340
+elif [[ ${SPARK_RUNTIME} == 'spark3.5.0' ]]; then
+  HADOOP_VERSION=3.3.5
+  HIVE_VERSION=3.1.3
+  DERBY_VERSION=10.14.1.0
+  FLINK_VERSION=1.18.0
+  SPARK_VERSION=3.5.0
+  SPARK_HADOOP_VERSION=3
+  CONFLUENT_VERSION=5.5.12
+  KAFKA_CONNECT_HDFS_VERSION=10.1.13
+  IMAGE_TAG=flink1180hive313spark350
 fi
 
 # Copy bundle jars to temp dir for mounting
diff --git a/packaging/bundle-validation/run_docker_java17.sh b/packaging/bundle-validation/run_docker_java17.sh
index 879b56367e0c0..d9f50cc90768a 100755
--- a/packaging/bundle-validation/run_docker_java17.sh
+++ b/packaging/bundle-validation/run_docker_java17.sh
@@ -93,6 +93,16 @@ elif [[ ${SPARK_RUNTIME} == 'spark3.4.0' ]]; then
   CONFLUENT_VERSION=5.5.12
   KAFKA_CONNECT_HDFS_VERSION=10.1.13
   IMAGE_TAG=flink1170hive313spark340
+elif [[ ${SPARK_RUNTIME} == 'spark3.5.0' ]]; then
+  HADOOP_VERSION=3.3.5
+  HIVE_VERSION=3.1.3
+  DERBY_VERSION=10.14.1.0
+  FLINK_VERSION=1.18.0
+  SPARK_VERSION=3.5.0
+  SPARK_HADOOP_VERSION=3
+  CONFLUENT_VERSION=5.5.12
+  KAFKA_CONNECT_HDFS_VERSION=10.1.13
+  IMAGE_TAG=flink1180hive313spark350
 fi
 
 # build docker image
diff --git a/packaging/hudi-utilities-bundle/pom.xml b/packaging/hudi-utilities-bundle/pom.xml
index dae88460ed07d..8516a61e589fa 100644
--- a/packaging/hudi-utilities-bundle/pom.xml
+++ b/packaging/hudi-utilities-bundle/pom.xml
@@ -123,6 +123,8 @@
                   <include>com.github.davidmoten:guava-mini</include>
                   <include>com.github.davidmoten:hilbert-curve</include>
                   <include>com.github.ben-manes.caffeine:caffeine</include>
+
+                  <include>com.google.protobuf:protobuf-java</include>
                   <include>com.twitter:bijection-avro_${scala.binary.version}</include>
                   <include>com.twitter:bijection-core_${scala.binary.version}</include>
                   <include>io.confluent:kafka-avro-serializer</include>
@@ -226,6 +228,10 @@
                   <pattern>org.apache.httpcomponents.</pattern>
                   <shadedPattern>org.apache.hudi.aws.org.apache.httpcomponents.</shadedPattern>
                 </relocation>
+                <relocation>
+                  <pattern>com.google.protobuf.</pattern>
+                  <shadedPattern>org.apache.hudi.com.google.protobuf.</shadedPattern>
+                </relocation>
                 <relocation>
                   <pattern>org.roaringbitmap.</pattern>
                   <shadedPattern>org.apache.hudi.org.roaringbitmap.</shadedPattern>
diff --git a/packaging/hudi-utilities-slim-bundle/pom.xml b/packaging/hudi-utilities-slim-bundle/pom.xml
index 9e02a1e94caf6..ec7a08af85d7c 100644
--- a/packaging/hudi-utilities-slim-bundle/pom.xml
+++ b/packaging/hudi-utilities-slim-bundle/pom.xml
@@ -109,6 +109,8 @@
                   <include>com.github.davidmoten:guava-mini</include>
                   <include>com.github.davidmoten:hilbert-curve</include>
+
+                  <include>com.google.protobuf:protobuf-java</include>
                   <include>com.twitter:bijection-avro_${scala.binary.version}</include>
                   <include>com.twitter:bijection-core_${scala.binary.version}</include>
                   <include>io.confluent:kafka-avro-serializer</include>
@@ -189,6 +191,10 @@
                 <pattern>org.openjdk.jol.</pattern>
                 <shadedPattern>org.apache.hudi.org.openjdk.jol.</shadedPattern>
               </relocation>
+              <relocation>
+                <pattern>com.google.protobuf.</pattern>
+                <shadedPattern>org.apache.hudi.com.google.protobuf.</shadedPattern>
+              </relocation>
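With protobuf-java now shaded into both utilities bundles, the relocation above rewrites references at package-prefix granularity: everything under com.google.protobuf. lands under org.apache.hudi.com.google.protobuf. inside the bundle jar. A hypothetical spot-check of that effect (the class name is illustrative; which protobuf classes actually end up in the jar depends on the build):

    // Hypothetical check, run with the utilities bundle on the classpath.
    public class ShadedProtobufCheck {
      public static void main(String[] args) throws ClassNotFoundException {
        // With the relocation in place, the bundled copy resolves under the shaded prefix.
        Class<?> shaded = Class.forName("org.apache.hudi.com.google.protobuf.Message");
        System.out.println("found shaded protobuf: " + shaded.getName());
        // The original name com.google.protobuf.Message should resolve only if an
        // external, non-bundled protobuf-java is also present.
      }
    }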
diff --git a/pom.xml b/pom.xml
index 2b2bf4665e556..6fdae65db3907 100644
--- a/pom.xml
+++ b/pom.xml
@@ -82,7 +82,7 @@
   3.2.0
   2.22.2
   2.22.2
-  3.2.4
+  3.4.0
   3.1.1
   3.8.0
   2.4
@@ -165,6 +165,7 @@
   3.2.3
   3.3.1
   3.4.1
+  3.5.0
   hudi-spark3.2.x
   hudi-spark3-common
   hudi-spark3.2plus-common
@@ ... @@
   ${scalatest.spark3.version}
   ${kafka.spark3.version}
+  2.8.1
-  1.12.3
-  1.8.3
-  1.11.1
+  1.13.1
+  1.9.1
+  1.11.2
   4.9.3
-  2.14.2
+  2.15.2
   ${fasterxml.spark3.version}
   ${fasterxml.spark3.version}
   ${fasterxml.spark3.version}
-  ${fasterxml.spark3.version}
+  ${fasterxml.spark3.version}
+  ${pulsar.spark.scala12.version}
-  2.19.0
-  2.0.6
+  2.20.0
+  2.0.7
   true
   true
-  hudi-spark-datasource/hudi-spark3.4.x
+  hudi-spark-datasource/hudi-spark3.5.x
   hudi-spark-datasource/hudi-spark3-common
   hudi-spark-datasource/hudi-spark3.2plus-common
@@ -2279,6 +2282,11 @@
         <version>${slf4j.version}</version>
         <scope>test</scope>
       </dependency>
+      <dependency>
+        <groupId>${hive.groupid}</groupId>
+        <artifactId>hive-storage-api</artifactId>
+        <version>${hive.storage.version}</version>
+      </dependency>
@@ -2508,6 +2516,66 @@
+    <profile>
+      <id>spark3.5</id>
+      <properties>
+        ${spark35.version}
+        ${spark3.version}
+        3.5
+        2.12.18
+        ${scala12.version}
+        2.12
+        hudi-spark3.5.x
+
+        hudi-spark3-common
+        hudi-spark3.2plus-common
+        ${scalatest.spark3.version}
+        ${kafka.spark3.version}
+        2.8.1
+
+        1.13.1
+        1.9.1
+        1.11.2
+        4.9.3
+        2.15.2
+        ${fasterxml.spark3.version}
+        ${fasterxml.spark3.version}
+        ${fasterxml.spark3.version}
+        ${fasterxml.spark3.version}
+        ${pulsar.spark.scala12.version}
+        2.20.0
+        2.0.7
+        true
+        true
+      </properties>
+      <modules>
+        <module>hudi-spark-datasource/hudi-spark3.5.x</module>
+        <module>hudi-spark-datasource/hudi-spark3-common</module>
+        <module>hudi-spark-datasource/hudi-spark3.2plus-common</module>
+      </modules>
+      <dependencies>
+        <dependency>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-log4j12</artifactId>
+          <version>${slf4j.version}</version>
+          <scope>test</scope>
+        </dependency>
+        <dependency>
+          <groupId>${hive.groupid}</groupId>
+          <artifactId>hive-storage-api</artifactId>
+          <version>${hive.storage.version}</version>
+        </dependency>
+      </dependencies>
+      <activation>
+        <property>
+          <name>spark3.5</name>
+        </property>
+      </activation>
+    </profile>
+
     <profile>
       <id>flink1.18</id>
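With the profile above in place, and assuming the property-based activation reconstructed here (name spark3.5) matches the convention of the existing Spark profiles, the Spark 3.5 modules would be selected with a build such as mvn clean package -DskipTests -Dspark3.5 -Dscala-2.12, mirroring how the spark3.4 profile is activated today.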