diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/AvroConversionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/AvroConversionUtils.scala index 47a90cc2f63e4..9f5d6fd7afee9 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/AvroConversionUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/AvroConversionUtils.scala @@ -18,9 +18,8 @@ package org.apache.hudi -import org.apache.avro.Schema.Type import org.apache.avro.generic.GenericRecord -import org.apache.avro.{AvroRuntimeException, JsonProperties, Schema} +import org.apache.avro.{JsonProperties, Schema} import org.apache.hudi.HoodieSparkUtils.sparkAdapter import org.apache.hudi.avro.AvroSchemaUtils import org.apache.spark.rdd.RDD @@ -29,31 +28,9 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} import org.apache.spark.sql.{Dataset, Row, SparkSession} import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ object AvroConversionUtils { - /** - * Check the nullability of the input Avro type and resolve it when it is nullable. The first - * return value is a [[Boolean]] indicating if the input Avro type is nullable. The second - * return value is either provided Avro type if it's not nullable, or its resolved non-nullable part - * in case it is - */ - def resolveAvroTypeNullability(avroType: Schema): (Boolean, Schema) = { - if (avroType.getType == Type.UNION) { - val fields = avroType.getTypes.asScala - val actualType = fields.filter(_.getType != Type.NULL) - if (fields.length != 2 || actualType.length != 1) { - throw new AvroRuntimeException( - s"Unsupported Avro UNION type $avroType: Only UNION of a null type and a non-null " + - "type is supported") - } - (true, actualType.head) - } else { - (false, avroType) - } - } - /** * Creates converter to transform Avro payload into Spark's Catalyst one * @@ -104,7 +81,7 @@ object AvroConversionUtils { recordNamespace: String): Row => GenericRecord = { val serde = sparkAdapter.createSparkRowSerDe(sourceSqlType) val avroSchema = AvroConversionUtils.convertStructTypeToAvroSchema(sourceSqlType, structName, recordNamespace) - val (nullable, _) = resolveAvroTypeNullability(avroSchema) + val nullable = AvroSchemaUtils.resolveNullableSchema(avroSchema) != avroSchema val converter = AvroConversionUtils.createInternalRowToAvroConverter(sourceSqlType, avroSchema, nullable) diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala index bd7d3647b2ea2..dec0eb5805cf2 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/HoodieSparkUtils.scala @@ -21,7 +21,7 @@ package org.apache.hudi import org.apache.avro.Schema import org.apache.avro.generic.GenericRecord import org.apache.hudi.HoodieConversionUtils.toScalaOption -import org.apache.hudi.avro.HoodieAvroUtils +import org.apache.hudi.avro.{AvroSchemaUtils, HoodieAvroUtils} import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.hudi.common.model.HoodieRecord import org.apache.spark.SPARK_VERSION @@ -84,7 +84,7 @@ object HoodieSparkUtils extends SparkAdapterSupport with SparkVersionsSupport { // making Spark deserialize its internal representation [[InternalRow]] into [[Row]] for subsequent conversion // (and back) val sameSchema = 
writerAvroSchema.equals(readerAvroSchema) - val (nullable, _) = AvroConversionUtils.resolveAvroTypeNullability(writerAvroSchema) + val nullable = AvroSchemaUtils.resolveNullableSchema(writerAvroSchema) != writerAvroSchema // NOTE: We have to serialize Avro schema, and then subsequently parse it on the executor node, since Spark // serializer is not able to digest it diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/JFunction.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/JFunction.scala index 3517d6414483a..1102a7230572b 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/JFunction.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/util/JFunction.scala @@ -27,6 +27,8 @@ import scala.language.implicitConversions */ object JFunction { + def scalaFunction1Noop[T]: T => Unit = _ => {} + //////////////////////////////////////////////////////////// // From Java to Scala //////////////////////////////////////////////////////////// diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala index cdd3bfd8179b4..a83afd514f1c3 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystExpressionUtils.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql +import org.apache.hudi.SparkAdapterSupport import org.apache.hudi.SparkAdapterSupport.sparkAdapter import org.apache.hudi.common.util.ValidationUtils.checkState import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeEq, AttributeReference, Cast, Expression, Like, Literal, MutableProjection, SubqueryExpression, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateStruct, Expression, GetStructField, Like, Literal, Projection, SubqueryExpression, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.internal.SQLConf @@ -44,6 +47,9 @@ trait HoodieCatalystExpressionUtils { */ def normalizeExprs(exprs: Seq[Expression], attributes: Seq[Attribute]): Seq[Expression] + // TODO scala-doc + def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] + /** * Matches an expression iff * @@ -75,7 +81,7 @@ trait HoodieCatalystExpressionUtils { def unapplyCastExpression(expr: Expression): Option[(Expression, DataType, Option[String], Boolean)] } -object HoodieCatalystExpressionUtils { +object HoodieCatalystExpressionUtils extends SparkAdapterSupport { /** * Convenience extractor allowing to untuple [[Cast]] across Spark versions @@ -85,6 +91,12 @@ object HoodieCatalystExpressionUtils { sparkAdapter.getCatalystExpressionUtils.unapplyCastExpression(expr) } + /** + * Leverages [[AttributeEquals]] predicate on 2 provided [[Attribute]]s + */ + def attributeEquals(one: Attribute, other: Attribute): Boolean = + new AttributeEq(one).equals(new AttributeEq(other)) + /** * Generates instance of [[UnsafeProjection]] projecting row of one 
[[StructType]] into another [[StructType]] * diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystPlansUtils.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystPlansUtils.scala index efd0eacac7329..b4d1fe24c7d9e 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystPlansUtils.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/HoodieCatalystPlansUtils.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} -import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.internal.SQLConf trait HoodieCatalystPlansUtils { @@ -48,47 +47,21 @@ trait HoodieCatalystPlansUtils { */ def createExplainCommand(plan: LogicalPlan, extended: Boolean): LogicalPlan - /** - * Convert a AliasIdentifier to TableIdentifier. - */ - def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier - - /** - * Convert a UnresolvedRelation to TableIdentifier. - */ - def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier - /** * Create Join logical plan. */ def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join /** - * Test if the logical plan is a Insert Into LogicalPlan. - */ - def isInsertInto(plan: LogicalPlan): Boolean - - /** - * Get the member of the Insert Into LogicalPlan. - */ - def getInsertIntoChildren(plan: LogicalPlan): - Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] - - /** - * if the logical plan is a TimeTravelRelation LogicalPlan. - */ - def isRelationTimeTravel(plan: LogicalPlan): Boolean - - /** - * Get the member of the TimeTravelRelation LogicalPlan. + * Decomposes [[InsertIntoStatement]] into its arguments allowing to accommodate for API + * changes in Spark 3.3 */ - def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] + def unapplyInsertIntoStatement(plan: LogicalPlan): Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] /** - * Create a Insert Into LogicalPlan. + * Rebases instance of {@code InsertIntoStatement} onto provided instance of {@code targetTable} and {@code query} */ - def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]], - query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan + def rebaseInsertIntoStatement(iis: LogicalPlan, targetTable: LogicalPlan, query: LogicalPlan): LogicalPlan /** * Test if the logical plan is a Repair Table LogicalPlan. @@ -98,6 +71,5 @@ trait HoodieCatalystPlansUtils { /** * Get the member of the Repair Table LogicalPlan. 
*/ - def getRepairTableChildren(plan: LogicalPlan): - Option[(TableIdentifier, Boolean, Boolean, String)] + def getRepairTableChildren(plan: LogicalPlan): Option[(TableIdentifier, Boolean, Boolean, String)] } diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeEq.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeEq.scala new file mode 100644 index 0000000000000..efb6e39112285 --- /dev/null +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeEq.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * This class primarily serves as a proxy for [[AttributeEquals]] inaccessible outside + * the current package + */ +class AttributeEq(attr: Attribute) extends AttributeEquals(attr) {} diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala index 5853b4eb8a8cb..4775af504bc3b 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala @@ -25,15 +25,21 @@ import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.hudi.common.util.TablePathUtils import org.apache.spark.sql._ import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSchemaConverters, HoodieAvroSerializer} +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate} import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.parser.HoodieExtendedParserInterface import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.BaseRelation -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, StructType} import org.apache.spark.storage.StorageLevel import java.util.Locale @@ -48,6 +54,12 @@ 
trait SparkAdapter extends Serializable { */ def isColumnarBatchRow(r: InternalRow): Boolean + /** + * Creates Catalyst [[Metadata]] for Hudi's meta-fields (designating these w/ + * [[METADATA_COL_ATTR_KEY]] if available (available in Spark >= 3.2) + */ + def createCatalystMetadataForMetaField: Metadata + /** * Inject table-valued functions to SparkSessionExtensions */ @@ -96,36 +108,31 @@ trait SparkAdapter extends Serializable { /** * Create the hoodie's extended spark sql parser. */ - def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = None + def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface /** * Create the SparkParsePartitionUtil. */ def getSparkParsePartitionUtil: SparkParsePartitionUtil - /** - * ParserInterface#parseMultipartIdentifier is supported since spark3, for spark2 this should not be called. - */ - def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] - /** * Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`. */ def getFilePartitions(sparkSession: SparkSession, partitionedFiles: Seq[PartitionedFile], maxSplitBytes: Long): Seq[FilePartition] - def isHoodieTable(table: LogicalPlan, spark: SparkSession): Boolean = { - unfoldSubqueryAliases(table) match { - case LogicalRelation(_, _, Some(table), _) => isHoodieTable(table) - // This is to handle the cases when table is loaded by providing - // the path to the Spark DS and not from the catalog - case LogicalRelation(fsr: HadoopFsRelation, _, _, _) => - fsr.options.get("path").map { pathStr => - val path = new Path(pathStr) - TablePathUtils.isHoodieTablePath(path.getFileSystem(spark.sparkContext.hadoopConfiguration), path) - } getOrElse(false) - - case _ => false + /** + * Checks whether [[LogicalPlan]] refers to Hudi table, and if it's the case extracts + * corresponding [[CatalogTable]] + */ + def resolveHoodieTable(plan: LogicalPlan): Option[CatalogTable] = { + EliminateSubqueryAliases(plan) match { + // First, we need to weed out unresolved plans + case plan if !plan.resolved => None + // NOTE: When resolving Hudi table we allow [[Filter]]s and [[Project]]s be applied + // on top of it + case PhysicalOperation(_, _, LogicalRelation(_, _, Some(table), _)) if isHoodieTable(table) => Some(table) + case _ => None } } @@ -142,15 +149,6 @@ trait SparkAdapter extends Serializable { isHoodieTable(table) } - protected def unfoldSubqueryAliases(plan: LogicalPlan): LogicalPlan = { - plan match { - case SubqueryAlias(_, relation: LogicalPlan) => - unfoldSubqueryAliases(relation) - case other => - other - } - } - /** * Create instance of [[ParquetFileFormat]] */ @@ -182,28 +180,12 @@ trait SparkAdapter extends Serializable { readDataSchema: StructType, metadataColumns: Seq[AttributeReference] = Seq.empty): FileScanRDD - /** - * Resolve [[DeleteFromTable]] - * SPARK-38626 condition is no longer Option in Spark 3.3 - */ - def resolveDeleteFromTable(deleteFromTable: Command, - resolveExpression: Expression => Expression): LogicalPlan - /** * Extract condition in [[DeleteFromTable]] * SPARK-38626 condition is no longer Option in Spark 3.3 */ def extractDeleteCondition(deleteFromTable: Command): Expression - /** - * Get parseQuery from ExtendedSqlParser, only for Spark 3.3+ - */ - def getQueryParserFromExtendedSqlParser(session: SparkSession, delegate: ParserInterface, - sqlText: String): LogicalPlan = { - // unsupported by default - throw new UnsupportedOperationException(s"Unsupported 
parseQuery method in Spark earlier than Spark 3.3.0") - } - /** * Converts instance of [[StorageLevel]] to a corresponding string */ diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/parser/HoodieExtendedParserInterface.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/parser/HoodieExtendedParserInterface.scala new file mode 100644 index 0000000000000..bc282d1ac6a29 --- /dev/null +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/parser/HoodieExtendedParserInterface.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parser + +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * This trait helps us to bridge compatibility gap of [[ParserInterface]] b/w different + * Spark versions + */ +trait HoodieExtendedParserInterface extends ParserInterface { + + def parseQuery(sqlText: String): LogicalPlan = { + throw new UnsupportedOperationException(s"Unsupported, parseQuery is implemented in Spark >= 3.3.0") + } + + def parseMultipartIdentifier(sqlText: String): Seq[String] = { + throw new UnsupportedOperationException(s"Unsupported, parseMultipartIdentifier is implemented in Spark >= 3.0.0") + } + +} diff --git a/hudi-common/src/main/java/org/apache/hudi/common/util/CollectionUtils.java b/hudi-common/src/main/java/org/apache/hudi/common/util/CollectionUtils.java index 3faddb91b4d83..cbda19ffec7e5 100644 --- a/hudi-common/src/main/java/org/apache/hudi/common/util/CollectionUtils.java +++ b/hudi-common/src/main/java/org/apache/hudi/common/util/CollectionUtils.java @@ -69,6 +69,26 @@ public static boolean nonEmpty(Collection c) { return !isNullOrEmpty(c); } + /** + * Reduces provided {@link Collection} using provided {@code reducer} applied to + * every element of the collection like following + * + * {@code reduce(reduce(reduce(identity, e1), e2), ...)} + * + * @param c target collection to be reduced + * @param identity element for reducing to start from + * @param reducer actual reducing operator + * + * @return result of the reduction of the collection using reducing operator + */ + public static U reduce(Collection c, U identity, BiFunction reducer) { + return c.stream() + .sequential() + .reduce(identity, reducer, (a, b) -> { + throw new UnsupportedOperationException(); + }); + } + /** * Makes a copy of provided {@link Properties} object */ diff --git a/hudi-common/src/main/java/org/apache/hudi/internal/schema/InternalSchema.java b/hudi-common/src/main/java/org/apache/hudi/internal/schema/InternalSchema.java index 229d6b66388d6..237eb95285c71 100644 --- a/hudi-common/src/main/java/org/apache/hudi/internal/schema/InternalSchema.java +++ 
b/hudi-common/src/main/java/org/apache/hudi/internal/schema/InternalSchema.java @@ -23,7 +23,7 @@ import org.apache.hudi.internal.schema.Types.RecordType; import java.io.Serializable; -import java.util.Arrays; +import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.List; @@ -118,7 +118,7 @@ public List getAllColsFullName() { if (nameToId == null) { nameToId = InternalSchemaBuilder.getBuilder().buildNameToId(record); } - return Arrays.asList(nameToId.keySet().toArray(new String[0])); + return new ArrayList<>(nameToId.keySet()); } /** @@ -241,15 +241,24 @@ public Field findField(String name) { } /** - * Whether colName exists in current Schema. - * Case insensitive. + * Whether {@code colName} exists in the current Schema * - * @param colName a colName - * @return Whether colName exists in current Schema + * @param colName a column name + * @param caseSensitive whether columns names should be treated as case-sensitive + * @return whether schema contains column identified by {@code colName} */ - public boolean findDuplicateCol(String colName) { - return idToName.entrySet().stream().map(e -> e.getValue().toLowerCase(Locale.ROOT)) - .collect(Collectors.toSet()).contains(colName); + public boolean hasColumn(String colName, boolean caseSensitive) { + if (caseSensitive) { + // In case we do a case-sensitive check we just need to validate whether + // schema contains field-name as it is + return idToName.containsValue(colName); + } else { + return idToName.values() + .stream() + .map(fieldName -> fieldName.toLowerCase(Locale.ROOT)) + .collect(Collectors.toSet()) + .contains(colName.toLowerCase(Locale.ROOT)); + } } public int findIdByName(String name) { diff --git a/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChange.java b/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChange.java index 2fe7a52e616ec..dc48c8c16ba70 100644 --- a/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChange.java +++ b/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChange.java @@ -83,10 +83,16 @@ abstract class BaseColumnChange implements TableChange { protected final InternalSchema internalSchema; protected final Map id2parent; protected final Map> positionChangeMap = new HashMap<>(); + protected final boolean caseSensitive; BaseColumnChange(InternalSchema schema) { + this(schema, false); + } + + BaseColumnChange(InternalSchema schema, boolean caseSensitive) { this.internalSchema = schema; this.id2parent = InternalSchemaBuilder.getBuilder().index2Parents(schema.getRecord()); + this.caseSensitive = caseSensitive; } /** diff --git a/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChanges.java b/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChanges.java index 3142f67e8fb29..6056d51d2e809 100644 --- a/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChanges.java +++ b/hudi-common/src/main/java/org/apache/hudi/internal/schema/action/TableChanges.java @@ -29,7 +29,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -44,12 +43,12 @@ public class TableChanges { public static class ColumnUpdateChange extends TableChange.BaseColumnChange { private final Map updates = new HashMap<>(); - public static ColumnUpdateChange get(InternalSchema schema) { - return new ColumnUpdateChange(schema); + private 
ColumnUpdateChange(InternalSchema schema) { + super(schema, false); } - private ColumnUpdateChange(InternalSchema schema) { - super(schema); + private ColumnUpdateChange(InternalSchema schema, boolean caseSensitive) { + super(schema, caseSensitive); } @Override @@ -160,8 +159,7 @@ public ColumnUpdateChange renameColumn(String name, String newName) { if (newName == null || newName.isEmpty()) { throw new IllegalArgumentException(String.format("cannot rename column: %s to empty", name)); } - // keep consisitent with hive. column names insensitive, so we check 'newName.toLowerCase(Locale.ROOT)' - if (internalSchema.findDuplicateCol(newName.toLowerCase(Locale.ROOT))) { + if (internalSchema.hasColumn(newName, caseSensitive)) { throw new IllegalArgumentException(String.format("cannot rename column: %s to a existing name", name)); } // save update info @@ -229,6 +227,14 @@ protected Integer findIdByFullName(String fullName) { throw new IllegalArgumentException(String.format("cannot find col id for given column fullName: %s", fullName)); } } + + public static ColumnUpdateChange get(InternalSchema schema) { + return new ColumnUpdateChange(schema); + } + + public static ColumnUpdateChange get(InternalSchema schema, boolean caseSensitive) { + return new ColumnUpdateChange(schema, caseSensitive); + } } /** Deal with delete columns changes for table. */ @@ -340,8 +346,7 @@ private void addColumnsInternal(String parent, String name, Type type, String do } fullName = parent + "." + name; } else { - // keep consistent with hive, column name case insensitive - if (internalSchema.findDuplicateCol(name.toLowerCase(Locale.ROOT))) { + if (internalSchema.hasColumn(name, caseSensitive)) { throw new HoodieSchemaException(String.format("cannot add column: %s which already exist", name)); } } diff --git a/hudi-common/src/main/java/org/apache/hudi/internal/schema/utils/AvroSchemaEvolutionUtils.java b/hudi-common/src/main/java/org/apache/hudi/internal/schema/utils/AvroSchemaEvolutionUtils.java index 060cd56926cd8..2dab3d009b406 100644 --- a/hudi-common/src/main/java/org/apache/hudi/internal/schema/utils/AvroSchemaEvolutionUtils.java +++ b/hudi-common/src/main/java/org/apache/hudi/internal/schema/utils/AvroSchemaEvolutionUtils.java @@ -18,16 +18,17 @@ package org.apache.hudi.internal.schema.utils; +import org.apache.avro.Schema; import org.apache.hudi.internal.schema.InternalSchema; import org.apache.hudi.internal.schema.action.TableChanges; -import org.apache.hudi.internal.schema.convert.AvroInternalSchemaConverter; - -import org.apache.avro.Schema; import java.util.List; import java.util.TreeMap; import java.util.stream.Collectors; +import static org.apache.hudi.common.util.CollectionUtils.reduce; +import static org.apache.hudi.internal.schema.convert.AvroInternalSchemaConverter.convert; + /** * Utility methods to support evolve old avro schema based on a given schema. */ @@ -50,7 +51,7 @@ public class AvroSchemaEvolutionUtils { * @return reconcile Schema */ public static InternalSchema reconcileSchema(Schema incomingSchema, InternalSchema oldTableSchema) { - InternalSchema inComingInternalSchema = AvroInternalSchemaConverter.convert(incomingSchema); + InternalSchema inComingInternalSchema = convert(incomingSchema); // check column add/missing List colNamesFromIncoming = inComingInternalSchema.getAllColsFullName(); List colNamesFromOldSchema = oldTableSchema.getAllColsFullName(); @@ -109,40 +110,39 @@ public static InternalSchema reconcileSchema(Schema incomingSchema, InternalSche } /** - * Canonical the nullability. 
- * Do not allow change cols Nullability field from optional to required. - * If above problem occurs, try to correct it. + * Reconciles nullability requirements b/w {@code source} and {@code target} schemas, + * by adjusting these of the {@code source} schema to be in-line with the ones of the + * {@code target} one * - * @param writeSchema writeSchema hoodie used to write data. - * @param readSchema read schema - * @return canonical Schema + * @param sourceSchema source schema that needs reconciliation + * @param targetSchema target schema that source schema will be reconciled against + * @return schema (based off {@code source} one) that has nullability constraints reconciled */ - public static Schema canonicalizeColumnNullability(Schema writeSchema, Schema readSchema) { - if (writeSchema.getFields().isEmpty() || readSchema.getFields().isEmpty()) { - return writeSchema; + public static Schema reconcileNullability(Schema sourceSchema, Schema targetSchema) { + if (sourceSchema.getFields().isEmpty() || targetSchema.getFields().isEmpty()) { + return sourceSchema; } - InternalSchema writeInternalSchema = AvroInternalSchemaConverter.convert(writeSchema); - InternalSchema readInternalSchema = AvroInternalSchemaConverter.convert(readSchema); - List colNamesWriteSchema = writeInternalSchema.getAllColsFullName(); - List colNamesFromReadSchema = readInternalSchema.getAllColsFullName(); - // try to deal with optional change. now when we use sparksql to update hudi table, - // sparksql Will change the col type from optional to required, this is a bug. - List candidateUpdateCols = colNamesWriteSchema.stream().filter(f -> { - boolean exist = colNamesFromReadSchema.contains(f); - if (exist && (writeInternalSchema.findField(f).isOptional() != readInternalSchema.findField(f).isOptional())) { - return true; - } else { - return false; - } - }).collect(Collectors.toList()); + + InternalSchema sourceInternalSchema = convert(sourceSchema); + InternalSchema targetInternalSchema = convert(targetSchema); + + List colNamesSourceSchema = sourceInternalSchema.getAllColsFullName(); + List colNamesTargetSchema = targetInternalSchema.getAllColsFullName(); + List candidateUpdateCols = colNamesSourceSchema.stream() + .filter(f -> colNamesTargetSchema.contains(f) + && sourceInternalSchema.findField(f).isOptional() != targetInternalSchema.findField(f).isOptional()) + .collect(Collectors.toList()); + if (candidateUpdateCols.isEmpty()) { - return writeSchema; + return sourceSchema; } - // try to correct all changes - TableChanges.ColumnUpdateChange updateChange = TableChanges.ColumnUpdateChange.get(writeInternalSchema); - candidateUpdateCols.stream().forEach(f -> updateChange.updateColumnNullability(f, true)); - InternalSchema updatedSchema = SchemaChangeUtils.applyTableChanges2Schema(writeInternalSchema, updateChange); - return AvroInternalSchemaConverter.convert(updatedSchema, writeSchema.getFullName()); + + // Reconcile nullability constraints (by executing phony schema change) + TableChanges.ColumnUpdateChange schemaChange = + reduce(candidateUpdateCols, TableChanges.ColumnUpdateChange.get(sourceInternalSchema), + (change, field) -> change.updateColumnNullability(field, true)); + + return convert(SchemaChangeUtils.applyTableChanges2Schema(sourceInternalSchema, schemaChange), sourceSchema.getFullName()); } } diff --git a/hudi-common/src/main/java/org/apache/hudi/internal/schema/visitor/NameToIDVisitor.java b/hudi-common/src/main/java/org/apache/hudi/internal/schema/visitor/NameToIDVisitor.java index 
60737c5b1b12d..150770a336be5 100644 --- a/hudi-common/src/main/java/org/apache/hudi/internal/schema/visitor/NameToIDVisitor.java +++ b/hudi-common/src/main/java/org/apache/hudi/internal/schema/visitor/NameToIDVisitor.java @@ -34,7 +34,7 @@ * Schema visitor to produce name -> id map for internalSchema. */ public class NameToIDVisitor extends InternalSchemaVisitor> { - private final Deque fieldNames = new LinkedList<>(); + private final Deque fieldNames = new LinkedList<>(); private final Map nameToId = new HashMap<>(); @Override diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala index 8a730a8334bce..99b5b5c87bae5 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieBaseRelation.scala @@ -89,8 +89,7 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext, extends BaseRelation with FileRelation with PrunedFilteredScan - with Logging - with SparkAdapterSupport { + with Logging { type FileSplit <: HoodieFileSplit type Relation <: HoodieBaseRelation @@ -173,7 +172,22 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext, (avroSchema, internalSchemaOpt) } - protected lazy val tableStructSchema: StructType = AvroConversionUtils.convertAvroSchemaToStructType(tableAvroSchema) + protected lazy val tableStructSchema: StructType = { + val converted = AvroConversionUtils.convertAvroSchemaToStructType(tableAvroSchema) + + val resolver = sparkSession.sessionState.analyzer.resolver + val metaFieldMetadata = sparkAdapter.createCatalystMetadataForMetaField + + // NOTE: Here we annotate meta-fields with corresponding metadata such that Spark (>= 3.2) + // is able to recognize such fields as meta-fields + StructType(converted.map { field => + if (metaFieldNames.exists(metaFieldName => resolver(metaFieldName, field.name))) { + field.copy(metadata = metaFieldMetadata) + } else { + field + } + }) + } protected lazy val partitionColumns: Array[String] = tableConfig.getPartitionFields.orElse(Array.empty) @@ -609,6 +623,8 @@ abstract class HoodieBaseRelation(val sqlContext: SQLContext, object HoodieBaseRelation extends SparkAdapterSupport { + private lazy val metaFieldNames = HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet + case class BaseFileReader(read: PartitionedFile => Iterator[InternalRow], val schema: StructType) { def apply(file: PartitionedFile): Iterator[InternalRow] = read.apply(file) } diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala index 9b8d499fc1390..f9738dbd3e369 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/HoodieSparkSqlWriter.scala @@ -25,7 +25,7 @@ import org.apache.hudi.AvroConversionUtils.{convertStructTypeToAvroSchema, getAv import org.apache.hudi.DataSourceWriteOptions._ import org.apache.hudi.HoodieConversionUtils.{toProperties, toScalaOption} import org.apache.hudi.HoodieWriterUtils._ -import org.apache.hudi.avro.AvroSchemaUtils.{canProject, isCompatibleProjectionOf, isSchemaCompatible} +import org.apache.hudi.avro.AvroSchemaUtils.{canProject, isCompatibleProjectionOf, 
isSchemaCompatible, resolveNullableSchema} import org.apache.hudi.avro.HoodieAvroUtils import org.apache.hudi.avro.HoodieAvroUtils.removeMetadataFields import org.apache.hudi.client.common.HoodieSparkEngineContext @@ -47,6 +47,7 @@ import org.apache.hudi.hive.{HiveSyncConfigHolder, HiveSyncTool} import org.apache.hudi.internal.DataSourceInternalWriterHelper import org.apache.hudi.internal.schema.InternalSchema import org.apache.hudi.internal.schema.convert.AvroInternalSchemaConverter +import org.apache.hudi.internal.schema.utils.AvroSchemaEvolutionUtils.reconcileNullability import org.apache.hudi.internal.schema.utils.{AvroSchemaEvolutionUtils, SerDeHelper} import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory import org.apache.hudi.keygen.{SparkKeyGeneratorInterface, TimestampBasedAvroKeyGenerator, TimestampBasedKeyGenerator} @@ -419,7 +420,7 @@ object HoodieSparkSqlWriter { SQL_MERGE_INTO_WRITES.defaultValue.toString).toBoolean val canonicalizedSourceSchema = if (shouldCanonicalizeNullable) { - AvroSchemaEvolutionUtils.canonicalizeColumnNullability(sourceSchema, latestTableSchema) + canonicalizeSchema(sourceSchema, latestTableSchema) } else { sourceSchema } @@ -597,6 +598,21 @@ object HoodieSparkSqlWriter { } } + /** + * Canonicalizes [[sourceSchema]] by reconciling it w/ [[latestTableSchema]] in following + * + *
<ol> + *   <li>Nullability: making sure that nullability of the fields in the source schema is matching + *   that of the latest table's ones</li> + * </ol>
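+ * + * For example (illustrative only): if a field is declared as non-nullable (required) in the source schema + * but as nullable in the latest table schema, the reconciled source schema will redeclare that field as nullable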
+ * + * TODO support casing reconciliation + */ + private def canonicalizeSchema(sourceSchema: Schema, latestTableSchema: Schema): Schema = { + reconcileNullability(sourceSchema, latestTableSchema) + } + + /** * get latest internalSchema from table * @@ -743,7 +759,7 @@ object HoodieSparkSqlWriter { def validateSchemaForHoodieIsDeleted(schema: Schema): Unit = { if (schema.getField(HoodieRecord.HOODIE_IS_DELETED_FIELD) != null && - AvroConversionUtils.resolveAvroTypeNullability(schema.getField(HoodieRecord.HOODIE_IS_DELETED_FIELD).schema())._2.getType != Schema.Type.BOOLEAN) { + resolveNullableSchema(schema.getField(HoodieRecord.HOODIE_IS_DELETED_FIELD).schema()).getType != Schema.Type.BOOLEAN) { throw new HoodieException(HoodieRecord.HOODIE_IS_DELETED_FIELD + " has to be BOOLEAN type. Passed in dataframe's schema has type " + schema.getField(HoodieRecord.HOODIE_IS_DELETED_FIELD).schema().getType) } diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala similarity index 94% rename from hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala rename to hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala index e64709e7d83e5..5adf2bf9f8ac7 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/HoodieUnaryLikeSham.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.catalyst.trees.TreeNode * This is required just to be able to compile the code that relies on [[UnaryLike]] * (introduced in Spark 3.2) against Spark < 3.2 */ -trait HoodieUnaryLikeSham[T <: TreeNode[T]] { - self: TreeNode[T] => +trait HoodieUnaryLikeSham[T <: TreeNode[T]] { self: TreeNode[T] => + protected def withNewChildInternal(newChild: T): T + } diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlCommonUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlCommonUtils.scala index eb26ef52d34d8..8e589abbc18b3 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlCommonUtils.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlCommonUtils.scala @@ -51,13 +51,6 @@ object HoodieSqlCommonUtils extends SparkAdapterSupport { override def get() = new SimpleDateFormat("yyyy-MM-dd") }) - def getTableIdentifier(table: LogicalPlan): TableIdentifier = { - table match { - case SubqueryAlias(name, _) => sparkAdapter.getCatalystPlanUtils.toTableIdentifier(name) - case _ => throw new IllegalArgumentException(s"Illegal table: $table") - } - } - def getTableSqlSchema(metaClient: HoodieTableMetaClient, includeMetadataFields: Boolean = false): Option[StructType] = { val schemaResolver = new TableSchemaResolver(metaClient) @@ -130,15 +123,6 @@ object HoodieSqlCommonUtils extends SparkAdapterSupport { } } - private def tripAlias(plan: LogicalPlan): LogicalPlan = { - plan match { - case SubqueryAlias(_, relation: LogicalPlan) => - tripAlias(relation) - case other => - other - } - } - /** * Add the hoodie meta fields to the 
schema. * @param schema @@ -167,18 +151,7 @@ object HoodieSqlCommonUtils extends SparkAdapterSupport { metaFields.contains(name) } - def removeMetaFields(df: DataFrame): DataFrame = { - val withoutMetaColumns = df.logicalPlan.output - .filterNot(attr => isMetaField(attr.name)) - .map(new Column(_)) - if (withoutMetaColumns.length != df.logicalPlan.output.size) { - df.select(withoutMetaColumns: _*) - } else { - df - } - } - - def removeMetaFields(attrs: Seq[Attribute]): Seq[Attribute] = { + def removeMetaFields[T <: Attribute](attrs: Seq[T]): Seq[T] = { attrs.filterNot(attr => isMetaField(attr.name)) } @@ -244,19 +217,6 @@ object HoodieSqlCommonUtils extends SparkAdapterSupport { fs.exists(metaPath) } - /** - * Split the expression to a sub expression seq by the AND operation. - * @param expression - * @return - */ - def splitByAnd(expression: Expression): Seq[Expression] = { - expression match { - case And(left, right) => - splitByAnd(left) ++ splitByAnd(right) - case exp => Seq(exp) - } - } - /** * Append the spark config and table options to the baseConfig. */ @@ -336,10 +296,10 @@ object HoodieSqlCommonUtils extends SparkAdapterSupport { resolver(field.name, other.name) && field.dataType == other.dataType } - def castIfNeeded(child: Expression, dataType: DataType, conf: SQLConf): Expression = { + def castIfNeeded(child: Expression, dataType: DataType): Expression = { child match { case Literal(nul, NullType) => Literal(nul, dataType) - case expr if child.dataType != dataType => Cast(expr, dataType, Option(conf.sessionLocalTimeZone)) + case expr if child.dataType != dataType => Cast(expr, dataType, Option(SQLConf.get.sessionLocalTimeZone)) case _ => child } } diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/HoodieLeafRunnableCommand.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/HoodieLeafRunnableCommand.scala index 47e884e962d4b..1aa8efb18d569 100644 --- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/HoodieLeafRunnableCommand.scala +++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/command/HoodieLeafRunnableCommand.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hudi.command -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.trees.HoodieLeafLike import org.apache.spark.sql.execution.command.RunnableCommand +import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.isMetaField /** * Similar to `LeafRunnableCommand` in Spark3.2, `HoodieLeafRunnableCommand` mixed in @@ -27,3 +28,16 @@ import org.apache.spark.sql.execution.command.RunnableCommand * the `withNewChildrenInternal` method repeatedly. 
*/ trait HoodieLeafRunnableCommand extends RunnableCommand with HoodieLeafLike[LogicalPlan] + +object HoodieLeafRunnableCommand { + + private[hudi] def stripMetaFieldAttributes(query: LogicalPlan): LogicalPlan = { + val filteredOutput = query.output.filterNot(attr => isMetaField(attr.name)) + if (filteredOutput == query.output) { + query + } else { + Project(filteredOutput, query) + } + } + +} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlUtils.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlUtils.scala deleted file mode 100644 index 9a031e9200472..0000000000000 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/HoodieSqlUtils.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hudi - -import org.apache.hudi.SparkAdapterSupport -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{And, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{MergeIntoTable, SubqueryAlias} - -object HoodieSqlUtils extends SparkAdapterSupport { - - /** - * Get the TableIdentifier of the target table in MergeInto. - */ - def getMergeIntoTargetTableId(mergeInto: MergeIntoTable): TableIdentifier = { - val aliaId = mergeInto.targetTable match { - case SubqueryAlias(_, SubqueryAlias(tableId, _)) => tableId - case SubqueryAlias(tableId, _) => tableId - case plan => throw new IllegalArgumentException(s"Illegal plan $plan in target") - } - sparkAdapter.getCatalystPlanUtils.toTableIdentifier(aliaId) - } - - /** - * Split the expression to a sub expression seq by the AND operation. 
- * @param expression - * @return - */ - def splitByAnd(expression: Expression): Seq[Expression] = { - expression match { - case And(left, right) => - splitByAnd(left) ++ splitByAnd(right) - case exp => Seq(exp) - } - } -} diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index 4875892b0efc2..1470133df1743 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -17,82 +17,101 @@ package org.apache.spark.sql.hudi.analysis -import org.apache.hudi.DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL -import org.apache.hudi.common.model.HoodieRecord import org.apache.hudi.common.util.ReflectionUtils import org.apache.hudi.common.util.ReflectionUtils.loadClass -import org.apache.hudi.{DataSourceReadOptions, HoodieSparkUtils, SparkAdapterSupport} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.catalog.{CatalogUtils, HoodieCatalogTable} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, GenericInternalRow, Literal, NamedExpression} -import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, GenericInternalRow} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields} -import org.apache.spark.sql.hudi.HoodieSqlUtils._ +import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{isMetaField, removeMetaFields} +import org.apache.spark.sql.hudi.analysis.HoodieAnalysis.{MatchInsertIntoStatement, ResolvesToHudiTable, sparkAdapter} import org.apache.spark.sql.hudi.command._ import org.apache.spark.sql.hudi.command.procedures.{HoodieProcedures, Procedure, ProcedureArgs} -import org.apache.spark.sql.hudi.{HoodieOptionConfig, HoodieSqlCommonUtils} -import org.apache.spark.sql.types.StringType import org.apache.spark.sql.{AnalysisException, SparkSession} import java.util -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer -object HoodieAnalysis { +object HoodieAnalysis extends SparkAdapterSupport { type RuleBuilder = SparkSession => Rule[LogicalPlan] def customResolutionRules: Seq[RuleBuilder] = { - val rules: ListBuffer[RuleBuilder] = ListBuffer( - // Default rules - session => HoodieResolveReferences(session), - session => HoodieAnalysis(session) - ) + val rules: ListBuffer[RuleBuilder] = ListBuffer() + + // NOTE: This rule adjusts [[LogicalRelation]]s resolving into Hudi tables such that + // meta-fields are not affecting the resolution of the target columns to be updated by Spark. 
+ // For more details please check out the scala-doc of the rule + // TODO limit adapters to only Spark < 3.2 + val adaptIngestionTargetLogicalRelations: RuleBuilder = session => AdaptIngestionTargetLogicalRelations(session) + + if (HoodieSparkUtils.isSpark2) { + val spark2ResolveReferencesClass = "org.apache.spark.sql.catalyst.analysis.HoodieSpark2Analysis$ResolveReferences" + val spark2ResolveReferences: RuleBuilder = + session => ReflectionUtils.loadClass(spark2ResolveReferencesClass, session).asInstanceOf[Rule[LogicalPlan]] + + // TODO elaborate on the ordering + rules += (adaptIngestionTargetLogicalRelations, spark2ResolveReferences) + } else { + rules += adaptIngestionTargetLogicalRelations + } if (HoodieSparkUtils.gteqSpark3_2) { val dataSourceV2ToV1FallbackClass = "org.apache.spark.sql.hudi.analysis.HoodieDataSourceV2ToV1Fallback" val dataSourceV2ToV1Fallback: RuleBuilder = session => instantiateKlass(dataSourceV2ToV1FallbackClass, session) - val spark3AnalysisClass = "org.apache.spark.sql.hudi.analysis.HoodieSpark3Analysis" - val spark3Analysis: RuleBuilder = - session => instantiateKlass(spark3AnalysisClass, session) + val spark32PlusResolveReferencesClass = "org.apache.spark.sql.hudi.analysis.HoodieSpark32PlusResolveReferences" + val spark32PlusResolveReferences: RuleBuilder = + session => instantiateKlass(spark32PlusResolveReferencesClass, session) + + // NOTE: PLEASE READ CAREFULLY BEFORE CHANGING + // + // It's critical for this rules to follow in this order; re-ordering this rules might lead to changes in + // behavior of Spark's analysis phase (for ex, DataSource V2 to V1 fallback might not kick in before other rules, + // leading to all relations resolving as V2 instead of current expectation of them being resolved as V1) + rules ++= Seq(dataSourceV2ToV1Fallback, spark32PlusResolveReferences) + } + if (HoodieSparkUtils.isSpark3) { val resolveAlterTableCommandsClass = - if (HoodieSparkUtils.gteqSpark3_3) + if (HoodieSparkUtils.gteqSpark3_3) { "org.apache.spark.sql.hudi.Spark33ResolveHudiAlterTableCommand" - else "org.apache.spark.sql.hudi.Spark32ResolveHudiAlterTableCommand" + } else if (HoodieSparkUtils.gteqSpark3_2) { + "org.apache.spark.sql.hudi.Spark32ResolveHudiAlterTableCommand" + } else if (HoodieSparkUtils.gteqSpark3_1) { + "org.apache.spark.sql.hudi.Spark31ResolveHudiAlterTableCommand" + } else { + throw new IllegalStateException("Unsupported Spark version") + } + val resolveAlterTableCommands: RuleBuilder = session => instantiateKlass(resolveAlterTableCommandsClass, session) - // NOTE: PLEASE READ CAREFULLY - // - // It's critical for this rules to follow in this order, so that DataSource V2 to V1 fallback - // is performed prior to other rules being evaluated - rules ++= Seq(dataSourceV2ToV1Fallback, spark3Analysis, resolveAlterTableCommands) - - } else if (HoodieSparkUtils.gteqSpark3_1) { - val spark31ResolveAlterTableCommandsClass = "org.apache.spark.sql.hudi.Spark31ResolveHudiAlterTableCommand" - val spark31ResolveAlterTableCommands: RuleBuilder = - session => instantiateKlass(spark31ResolveAlterTableCommandsClass, session) - - rules ++= Seq(spark31ResolveAlterTableCommands) + rules += resolveAlterTableCommands } + // NOTE: Some of the conversions (for [[CreateTable]], [[InsertIntoStatement]] have to happen + // early to preempt execution of [[DataSourceAnalysis]] rule from Spark + // Please check rule's scala-doc for more details + rules += (_ => ResolveImplementationsEarly()) + rules } def customPostHocResolutionRules: Seq[RuleBuilder] = { val rules: 
ListBuffer[RuleBuilder] = ListBuffer( - // Default rules + // NOTE: By default all commands are converted into corresponding Hudi implementations during + // "post-hoc resolution" phase + session => ResolveImplementations(), session => HoodiePostAnalysisRule(session) ) if (HoodieSparkUtils.gteqSpark3_2) { - val spark3PostHocResolutionClass = "org.apache.spark.sql.hudi.analysis.HoodieSpark3PostAnalysisRule" + val spark3PostHocResolutionClass = "org.apache.spark.sql.hudi.analysis.HoodieSpark32PlusPostAnalysisRule" val spark3PostHocResolution: RuleBuilder = session => instantiateKlass(spark3PostHocResolutionClass, session) @@ -103,7 +122,10 @@ object HoodieAnalysis { } def customOptimizerRules: Seq[RuleBuilder] = { - val optimizerRules = ListBuffer[RuleBuilder]() + val rules: ListBuffer[RuleBuilder] = ListBuffer( + // Default rules + ) + if (HoodieSparkUtils.gteqSpark3_1) { val nestedSchemaPruningClass = if (HoodieSparkUtils.gteqSpark3_3) { @@ -115,8 +137,8 @@ object HoodieAnalysis { "org.apache.spark.sql.execution.datasources.Spark31NestedSchemaPruning" } - val nestedSchemaPruningRule = instantiateKlass(nestedSchemaPruningClass) - optimizerRules += (_ => nestedSchemaPruningRule) + val nestedSchemaPruningRule = ReflectionUtils.loadClass(nestedSchemaPruningClass).asInstanceOf[Rule[LogicalPlan]] + rules += (_ => nestedSchemaPruningRule) } // NOTE: [[HoodiePruneFileSourcePartitions]] is a replica in kind to Spark's @@ -128,15 +150,151 @@ object HoodieAnalysis { // To work this around, we injecting this as the rule that trails pre-CBO, ie it's // - Triggered before CBO, therefore have access to the same stats as CBO // - Precedes actual [[customEarlyScanPushDownRules]] invocation - optimizerRules += (spark => HoodiePruneFileSourcePartitions(spark)) + rules += (spark => HoodiePruneFileSourcePartitions(spark)) + + rules + } + + /** + * This rule adjusts output of the [[LogicalRelation]] resolving int Hudi tables such that all of the + * default Spark resolution could be applied resolving standard Spark SQL commands + * + *
<ul> + *   <li>`MERGE INTO ...`</li> + *   <li>`INSERT INTO ...`</li> + *   <li>`UPDATE ...`</li> + * </ul>
+ * + * even though Hudi tables might be carrying meta-fields that have to be ignored during resolution phase. + * + * Spark >= 3.2 bears fully-fledged support for meta-fields and such antics are not required for it: + * we just need to annotate corresponding attributes as "metadata" for Spark to be able to ignore it. + * + * In Spark < 3.2 however, this is worked around by simply removing any meta-fields from the output + * of the [[LogicalRelation]] resolving into Hudi table. Note that, it's a safe operation since we + * actually need to ignore these values anyway + */ + case class AdaptIngestionTargetLogicalRelations(spark: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = + AnalysisHelper.allowInvokingTransformsInAnalyzer { + plan transformDown { + // NOTE: In case of [[MergeIntoTable]] Hudi tables could be on both sides -- receiving and providing + // the data, as such we have to make sure that we handle both of these cases + case mit @ MergeIntoTable(targetTable, query, _, _, _) => + val updatedTargetTable = targetTable match { + // In the receiving side of the MIT, we can't project meta-field attributes out, + // and instead have to explicitly remove them + case ResolvesToHudiTable(_) => Some(stripMetaFieldsAttributes(targetTable)) + case _ => None + } + + val updatedQuery = query match { + // In the producing side of the MIT, we simply check whether the query will be yielding + // Hudi meta-fields attributes. In cases when it does we simply project them out + // + // NOTE: We have to handle both cases when [[query]] is fully resolved and when it's not, + // since, unfortunately, there's no reliable way for us to control the ordering of the + // application of the rules (during next iteration we might not even reach this rule again), + // therefore we have to make sure projection is handled in a single pass + case ProducesHudiMetaFields(output) => Some(projectOutMetaFieldsAttributes(query, output)) + case _ => None + } + + if (updatedTargetTable.isDefined || updatedQuery.isDefined) { + mit.copy( + targetTable = updatedTargetTable.getOrElse(targetTable), + sourceTable = updatedQuery.getOrElse(query) + ) + } else { + mit + } + + // NOTE: In case of [[InsertIntoStatement]] Hudi tables could be on both sides -- receiving and providing + // the data, as such we have to make sure that we handle both of these cases + case iis @ MatchInsertIntoStatement(targetTable, _, query, _, _) => + val updatedTargetTable = targetTable match { + // In the receiving side of the IIS, we can't project meta-field attributes out, + // and instead have to explicitly remove them + case ResolvesToHudiTable(_) => Some(stripMetaFieldsAttributes(targetTable)) + case _ => None + } + + val updatedQuery = query match { + // In the producing side of the MIT, we simply check whether the query will be yielding + // Hudi meta-fields attributes. 
In cases when it does we simply project them out + // + // NOTE: We have to handle both cases when [[query]] is fully resolved and when it's not, + // since, unfortunately, there's no reliable way for us to control the ordering of the + // application of the rules (during next iteration we might not even reach this rule again), + // therefore we have to make sure projection is handled in a single pass + case ProducesHudiMetaFields(output) => Some(projectOutMetaFieldsAttributes(query, output)) + case _ => None + } + + if (updatedTargetTable.isDefined || updatedQuery.isDefined) { + sparkAdapter.getCatalystPlanUtils.rebaseInsertIntoStatement(iis, + updatedTargetTable.getOrElse(targetTable), updatedQuery.getOrElse(query)) + } else { + iis + } + + case ut @ UpdateTable(relation @ ResolvesToHudiTable(_), _, _) => + ut.copy(table = projectOutResolvedMetaFieldsAttributes(relation)) + } + } + + private def projectOutMetaFieldsAttributes(plan: LogicalPlan, output: Seq[Attribute]): LogicalPlan = { + if (plan.resolved) { + projectOutResolvedMetaFieldsAttributes(plan) + } else { + projectOutUnresolvedMetaFieldsAttributes(plan, output) + } + } + + private def projectOutUnresolvedMetaFieldsAttributes(plan: LogicalPlan, expected: Seq[Attribute]): LogicalPlan = { + val filtered = expected.attrs.filterNot(attr => isMetaField(attr.name)) + if (filtered != expected) { + Project(filtered.map(attr => UnresolvedAttribute(attr.name)), plan) + } else { + plan + } + } + + private def projectOutResolvedMetaFieldsAttributes(plan: LogicalPlan): LogicalPlan = { + if (plan.output.exists(attr => isMetaField(attr.name))) { + Project(removeMetaFields(plan.output), plan) + } else { + plan + } + } + + private def stripMetaFieldsAttributes(plan: LogicalPlan): LogicalPlan = { + plan transformUp { + case lr: LogicalRelation if lr.output.exists(attr => isMetaField(attr.name)) => + lr.copy(output = removeMetaFields(lr.output)) + } + } + + private object ProducesHudiMetaFields { + + def unapply(plan: LogicalPlan): Option[Seq[Attribute]] = { + val resolved = if (plan.resolved) { + plan + } else { + val analyzer = spark.sessionState.analyzer + analyzer.execute(plan) + } - optimizerRules + if (resolved.output.exists(attr => isMetaField(attr.name))) { + Some(resolved.output) + } else { + None + } + } + } } - /* - // CBO is only supported in Spark >= 3.1.x - def customPreCBORules: Seq[RuleBuilder] = Seq() - */ private def instantiateKlass(klass: String): Rule[LogicalPlan] = { loadClass(klass).asInstanceOf[Rule[LogicalPlan]] } @@ -147,68 +305,95 @@ object HoodieAnalysis { loadClass(klass, Array(classOf[SparkSession]).asInstanceOf[Array[Class[_]]], session) .asInstanceOf[Rule[LogicalPlan]] } + + private[sql] object MatchInsertIntoStatement { + def unapply(plan: LogicalPlan): Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = + sparkAdapter.getCatalystPlanUtils.unapplyInsertIntoStatement(plan) + } + + private[sql] object ResolvesToHudiTable { + def unapply(plan: LogicalPlan): Option[CatalogTable] = + sparkAdapter.resolveHoodieTable(plan) + } + + private[sql] def failAnalysis(msg: String): Nothing = { + throw new AnalysisException(msg) + } } /** - * Rule for convert the logical plan to command. 
+ * Rule converting *fully-resolved* Spark SQL plans into Hudi's custom implementations * - * @param sparkSession + * NOTE: This is separated out from [[ResolveImplementations]] such that we can apply it + * during earlier stage (resolution), while the [[ResolveImplementations]] is applied at post-hoc + * resolution phase. This is necessary to make sure that [[ResolveImplementationsEarly]] preempts + * execution of the [[DataSourceAnalysis]] stage from Spark which would otherwise convert same commands + * into native Spark implementations (which are not compatible w/ Hudi) + */ +case class ResolveImplementationsEarly() extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + // Convert to InsertIntoHoodieTableCommand + case iis @ MatchInsertIntoStatement(relation @ ResolvesToHudiTable(_), partition, query, overwrite, _) if query.resolved => + relation match { + // NOTE: In Spark >= 3.2, Hudi relations will be resolved as [[DataSourceV2Relation]]s by default; + // However, currently, fallback will be applied downgrading them to V1 relations, hence + // we need to check whether we could proceed here, or has to wait until fallback rule kicks in + case lr: LogicalRelation => new InsertIntoHoodieTableCommand(lr, query, partition, overwrite) + case _ => iis + } + + // Convert to CreateHoodieTableAsSelectCommand + case ct @ CreateTable(table, mode, Some(query)) + if sparkAdapter.isHoodieTable(table) && ct.query.forall(_.resolved) => + CreateHoodieTableAsSelectCommand(table, mode, query) + + case _ => plan + } + } +} + +/** + * Rule converting *fully-resolved* Spark SQL plans into Hudi's custom implementations + * + * NOTE: This is executed in "post-hoc resolution" phase to make sure all of the commands have + * been resolved prior to that */ -case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] - with SparkAdapterSupport { +case class ResolveImplementations() extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { plan match { // Convert to MergeIntoHoodieTableCommand - case m @ MergeIntoTable(target, _, _, _, _) - if m.resolved && sparkAdapter.isHoodieTable(target, sparkSession) => - MergeIntoHoodieTableCommand(m) + case mit @ MergeIntoTable(target @ ResolvesToHudiTable(_), _, _, _, _) if mit.resolved => + MergeIntoHoodieTableCommand(mit) // Convert to UpdateHoodieTableCommand - case u @ UpdateTable(table, _, _) - if u.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - UpdateHoodieTableCommand(u) + case ut @ UpdateTable(plan @ ResolvesToHudiTable(_), _, _) if ut.resolved => + UpdateHoodieTableCommand(ut) // Convert to DeleteHoodieTableCommand - case d @ DeleteFromTable(table, _) - if d.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - DeleteHoodieTableCommand(d) - - // Convert to InsertIntoHoodieTableCommand - case l if sparkAdapter.getCatalystPlanUtils.isInsertInto(l) => - val (table, partition, query, overwrite, _) = sparkAdapter.getCatalystPlanUtils.getInsertIntoChildren(l).get - table match { - case relation: LogicalRelation if sparkAdapter.isHoodieTable(relation, sparkSession) => - new InsertIntoHoodieTableCommand(relation, query, partition, overwrite) - case _ => - l - } - - // Convert to CreateHoodieTableAsSelectCommand - case CreateTable(table, mode, Some(query)) - if query.resolved && sparkAdapter.isHoodieTable(table) => - CreateHoodieTableAsSelectCommand(table, mode, query) + case dft @ DeleteFromTable(plan @ ResolvesToHudiTable(_), _) if dft.resolved => + 
DeleteHoodieTableCommand(dft) // Convert to CompactionHoodieTableCommand - case CompactionTable(table, operation, options) - if table.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - val tableId = getTableIdentifier(table) - val catalogTable = sparkSession.sessionState.catalog.getTableMetadata(tableId) - CompactionHoodieTableCommand(catalogTable, operation, options) + case ct @ CompactionTable(plan @ ResolvesToHudiTable(table), operation, options) if ct.resolved => + CompactionHoodieTableCommand(table, operation, options) + // Convert to CompactionHoodiePathCommand - case CompactionPath(path, operation, options) => + case cp @ CompactionPath(path, operation, options) if cp.resolved => CompactionHoodiePathCommand(path, operation, options) + // Convert to CompactionShowOnTable - case CompactionShowOnTable(table, limit) - if sparkAdapter.isHoodieTable(table, sparkSession) => - val tableId = getTableIdentifier(table) - val catalogTable = sparkSession.sessionState.catalog.getTableMetadata(tableId) - CompactionShowHoodieTableCommand(catalogTable, limit) + case csot @ CompactionShowOnTable(plan @ ResolvesToHudiTable(table), limit) if csot.resolved => + CompactionShowHoodieTableCommand(table, limit) + // Convert to CompactionShowHoodiePathCommand - case CompactionShowOnPath(path, limit) => + case csop @ CompactionShowOnPath(path, limit) if csop.resolved => CompactionShowHoodiePathCommand(path, limit) + // Convert to HoodieCallProcedureCommand - case c@CallCommand(_, _) => + case c @ CallCommand(_, _) => val procedure: Option[Procedure] = loadProcedure(c.name) val input = buildProcedureArgs(c.args) if (procedure.nonEmpty) { @@ -218,25 +403,21 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] } // Convert to CreateIndexCommand - case CreateIndex(table, indexName, indexType, ignoreIfExists, columns, options, output) - if table.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - CreateIndexCommand( - getTableIdentifier(table), indexName, indexType, ignoreIfExists, columns, options, output) + case ci @ CreateIndex(plan @ ResolvesToHudiTable(table), indexName, indexType, ignoreIfExists, columns, options, output) => + // TODO need to resolve columns + CreateIndexCommand(table, indexName, indexType, ignoreIfExists, columns, options, output) // Convert to DropIndexCommand - case DropIndex(table, indexName, ignoreIfNotExists, output) - if table.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - DropIndexCommand(getTableIdentifier(table), indexName, ignoreIfNotExists, output) + case di @ DropIndex(plan @ ResolvesToHudiTable(table), indexName, ignoreIfNotExists, output) if di.resolved => + DropIndexCommand(table, indexName, ignoreIfNotExists, output) // Convert to ShowIndexesCommand - case ShowIndexes(table, output) - if table.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - ShowIndexesCommand(getTableIdentifier(table), output) + case si @ ShowIndexes(plan @ ResolvesToHudiTable(table), output) if si.resolved => + ShowIndexesCommand(table, output) // Covert to RefreshCommand - case RefreshIndex(table, indexName, output) - if table.resolved && sparkAdapter.isHoodieTable(table, sparkSession) => - RefreshIndexCommand(getTableIdentifier(table), indexName, output) + case ri @ RefreshIndex(plan @ ResolvesToHudiTable(table), indexName, output) if ri.resolved => + RefreshIndexCommand(table, indexName, output) case _ => plan } @@ -276,337 +457,11 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] 
} } -/** - * Rule for resolve hoodie's extended syntax or rewrite some logical plan. - * - * @param sparkSession - */ -case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[LogicalPlan] - with SparkAdapterSupport { - private lazy val analyzer = sparkSession.sessionState.analyzer - - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - // Resolve merge into - case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions) - if sparkAdapter.isHoodieTable(target, sparkSession) && target.resolved => - val resolver = sparkSession.sessionState.conf.resolver - val resolvedSource = analyzer.execute(source) - try { - analyzer.checkAnalysis(resolvedSource) - } catch { - case e: AnalysisException => - val ae = new AnalysisException(e.message, e.line, e.startPosition, Option(resolvedSource)) - ae.setStackTrace(e.getStackTrace) - throw ae - } - - def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = { - if (assignments.isEmpty) { - true - } else { - // This is a Hack for test if it is "update set *" or "insert *" for spark3. - // As spark3's own ResolveReference will append first five columns of the target - // table(which is the hoodie meta fields) to the assignments for "update set *" and - // "insert *", so we test if the first five assignmentFieldNames is the meta fields - // to judge if it is "update set *" or "insert *". - // We can do this because under the normal case, we should not allow to update or set - // the hoodie's meta field in sql statement, it is a system field, cannot set the value - // by user. - if (HoodieSparkUtils.isSpark3) { - val resolvedAssignments = assignments.map { assign => - val resolvedKey = assign.key match { - case c if !c.resolved => - resolveExpressionFrom(target)(c) - case o => o - } - Assignment(resolvedKey, null) - } - val assignmentFieldNames = resolvedAssignments.map(_.key).map { - case attr: AttributeReference => - attr.name - case _ => "" - }.toArray - val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala - if (assignmentFieldNames.take(metaFields.length).mkString(",").startsWith(metaFields.mkString(","))) { - true - } else { - false - } - } else { - false - } - } - } - - def resolveConditionAssignments(condition: Option[Expression], - assignments: Seq[Assignment]): (Option[Expression], Seq[Assignment]) = { - val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_)) - val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) { - // assignments is empty means insert * or update set * - val resolvedSourceOutput = resolvedSource.output.filter(attr => !HoodieSqlCommonUtils.isMetaField(attr.name)) - val targetOutput = target.output.filter(attr => !HoodieSqlCommonUtils.isMetaField(attr.name)) - val resolvedSourceColumnNames = resolvedSourceOutput.map(_.name) - - if(targetOutput.filter(attr => resolvedSourceColumnNames.exists(resolver(_, attr.name))).equals(targetOutput)){ - //If sourceTable's columns contains all targetTable's columns, - //We fill assign all the source fields to the target fields by column name matching. - targetOutput.map(targetAttr => { - val sourceAttr = resolvedSourceOutput.find(f => resolver(f.name, targetAttr.name)).get - Assignment(targetAttr, sourceAttr) - }) - } else { - // We fill assign all the source fields to the target fields by order. 
- targetOutput - .zip(resolvedSourceOutput) - .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) } - } - } else { - // For Spark3.2, InsertStarAction/UpdateStarAction's assignments will contain the meta fields. - val withoutMetaAttrs = assignments.filterNot{ assignment => - if (assignment.key.isInstanceOf[Attribute]) { - HoodieSqlCommonUtils.isMetaField(assignment.key.asInstanceOf[Attribute].name) - } else { - false - } - } - withoutMetaAttrs.map { assignment => - val resolvedKey = resolveExpressionFrom(target)(assignment.key) - val resolvedValue = resolveExpressionFrom(resolvedSource, Some(target))(assignment.value) - Assignment(resolvedKey, resolvedValue) - } - } - (resolvedCondition, resolvedAssignments) - } - - // Resolve the merge condition - val resolvedMergeCondition = resolveExpressionFrom(resolvedSource, Some(target))(mergeCondition) - - // Resolve the matchedActions - val resolvedMatchedActions = matchedActions.map { - case UpdateAction(condition, assignments) => - val (resolvedCondition, resolvedAssignments) = - resolveConditionAssignments(condition, assignments) - - // Get the target table type and pre-combine field. - val targetTableId = getMergeIntoTargetTableId(mergeInto) - val targetTable = - sparkSession.sessionState.catalog.getTableMetadata(targetTableId) - val tblProperties = targetTable.storage.properties ++ targetTable.properties - val targetTableType = HoodieOptionConfig.getTableType(tblProperties) - val preCombineField = HoodieOptionConfig.getPreCombineField(tblProperties) - - // Get the map of target attribute to value of the update assignments. - val target2Values = resolvedAssignments.map { - case Assignment(attr: AttributeReference, value) => - attr.name -> value - case o => throw new IllegalArgumentException(s"Assignment key must be an attribute, current is: ${o.key}") - }.toMap - - // Validate if there are incorrect target attributes. - val targetColumnNames = removeMetaFields(target.output).map(_.name) - val unKnowTargets = target2Values.keys - .filterNot(name => targetColumnNames.exists(resolver(_, name))) - if (unKnowTargets.nonEmpty) { - throw new AnalysisException(s"Cannot find target attributes: ${unKnowTargets.mkString(",")}.") - } - - // Fill the missing target attribute in the update action for COW table to support partial update. - // e.g. If the update action missing 'id' attribute, we fill a "id = target.id" to the update action. - val newAssignments = removeMetaFields(target.output) - .map(attr => { - val valueOption = target2Values.find(f => resolver(f._1, attr.name)) - // TODO support partial update for MOR. - if (valueOption.isEmpty && targetTableType == MOR_TABLE_TYPE_OPT_VAL) { - throw new AnalysisException(s"Missing specify the value for target field: '${attr.name}' in merge into update action" + - s" for MOR table. Currently we cannot support partial update for MOR," + - s" please complete all the target fields just like '...update set id = s0.id, name = s0.name ....'") - } - if (preCombineField.isDefined && preCombineField.get.equalsIgnoreCase(attr.name) - && valueOption.isEmpty) { - throw new AnalysisException(s"Missing specify value for the preCombineField:" + - s" ${preCombineField.get} in merge-into update action. You should add" + - s" '... update set ${preCombineField.get} = xx....' 
to the when-matched clause.") - } - Assignment(attr, if (valueOption.isEmpty) attr else valueOption.get._2) - }) - UpdateAction(resolvedCondition, newAssignments) - case DeleteAction(condition) => - val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_)) - DeleteAction(resolvedCondition) - case action: MergeAction => - // SPARK-34962: use UpdateStarAction as the explicit representation of * in UpdateAction. - // So match and covert this in Spark3.2 env. - val (resolvedCondition, resolvedAssignments) = - resolveConditionAssignments(action.condition, Seq.empty) - UpdateAction(resolvedCondition, resolvedAssignments) - } - // Resolve the notMatchedActions - val resolvedNotMatchedActions = notMatchedActions.map { - case InsertAction(condition, assignments) => - val (resolvedCondition, resolvedAssignments) = - resolveConditionAssignments(condition, assignments) - InsertAction(resolvedCondition, resolvedAssignments) - case action: MergeAction => - // SPARK-34962: use InsertStarAction as the explicit representation of * in InsertAction. - // So match and covert this in Spark3.2 env. - val (resolvedCondition, resolvedAssignments) = - resolveConditionAssignments(action.condition, Seq.empty) - InsertAction(resolvedCondition, resolvedAssignments) - } - // Return the resolved MergeIntoTable - MergeIntoTable(target, resolvedSource, resolvedMergeCondition, - resolvedMatchedActions, resolvedNotMatchedActions) - - // Resolve update table - case UpdateTable(table, assignments, condition) - if sparkAdapter.isHoodieTable(table, sparkSession) && table.resolved => - // Resolve condition - val resolvedCondition = condition.map(resolveExpressionFrom(table)(_)) - // Resolve assignments - val resolvedAssignments = assignments.map(assignment => { - val resolvedKey = resolveExpressionFrom(table)(assignment.key) - val resolvedValue = resolveExpressionFrom(table)(assignment.value) - Assignment(resolvedKey, resolvedValue) - }) - // Return the resolved UpdateTable - UpdateTable(table, resolvedAssignments, resolvedCondition) - - // Resolve Delete Table - case dft @ DeleteFromTable(table, condition) - if sparkAdapter.isHoodieTable(table, sparkSession) && table.resolved => - val resolveExpression = resolveExpressionFrom(table, None)(_) - sparkAdapter.resolveDeleteFromTable(dft, resolveExpression) - - // Append the meta field to the insert query to walk through the validate for the - // number of insert fields with the number of the target table fields. - case l if sparkAdapter.getCatalystPlanUtils.isInsertInto(l) => - val (table, partition, query, overwrite, ifPartitionNotExists) = - sparkAdapter.getCatalystPlanUtils.getInsertIntoChildren(l).get - - if (sparkAdapter.isHoodieTable(table, sparkSession) && query.resolved && - !containUnResolvedStar(query) && - !checkAlreadyAppendMetaField(query)) { - val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala.map( - Alias(Literal.create(null, StringType), _)()).toArray[NamedExpression] - val newQuery = query match { - case project: Project => - val withMetaFieldProjects = - metaFields ++ project.projectList - // Append the meta fields to the insert query. 
- Project(withMetaFieldProjects, project.child) - case _ => - val withMetaFieldProjects = metaFields ++ query.output - Project(withMetaFieldProjects, query) - } - sparkAdapter.getCatalystPlanUtils.createInsertInto(table, partition, newQuery, overwrite, ifPartitionNotExists) - } else { - l - } - - case l if sparkAdapter.getCatalystPlanUtils.isRelationTimeTravel(l) => - val (plan: UnresolvedRelation, timestamp, version) = - sparkAdapter.getCatalystPlanUtils.getRelationTimeTravel(l).get - - if (timestamp.isEmpty && version.nonEmpty) { - throw new AnalysisException( - "version expression is not supported for time travel") - } - - val tableIdentifier = sparkAdapter.getCatalystPlanUtils.toTableIdentifier(plan) - if (sparkAdapter.isHoodieTable(tableIdentifier, sparkSession)) { - val hoodieCatalogTable = HoodieCatalogTable(sparkSession, tableIdentifier) - val table = hoodieCatalogTable.table - val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) - val instantOption = Map( - DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> timestamp.get.toString()) - val dataSource = - DataSource( - sparkSession, - userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), - partitionColumns = table.partitionColumnNames, - bucketSpec = table.bucketSpec, - className = table.provider.get, - options = table.storage.properties ++ pathOption ++ instantOption, - catalogTable = Some(table)) - - LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) - } else { - l - } - - case p => p - } - - private def containUnResolvedStar(query: LogicalPlan): Boolean = { - query match { - case project: Project => project.projectList.exists(_.isInstanceOf[UnresolvedStar]) - case _ => false - } - } - - /** - * Check if the the query of insert statement has already append the meta fields to avoid - * duplicate append. - * - * @param query - * @return - */ - private def checkAlreadyAppendMetaField(query: LogicalPlan): Boolean = { - query.output.take(HoodieRecord.HOODIE_META_COLUMNS.size()) - .filter(isMetaField) - .map { - case AttributeReference(name, _, _, _) => name.toLowerCase - case other => throw new IllegalArgumentException(s"$other should not be a hoodie meta field") - }.toSet == HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet - } - - private def isMetaField(exp: Expression): Boolean = { - val metaFields = HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet - exp match { - case Alias(_, name) if metaFields.contains(name.toLowerCase) => true - case AttributeReference(name, _, _, _) if metaFields.contains(name.toLowerCase) => true - case _=> false - } - } - - /** - * Resolve the expression. - * 1、 Fake a a project for the expression based on the source plan - * 2、 Resolve the fake project - * 3、 Get the resolved expression from the faked project - * @param left The left source plan for the expression. - * @param right The right source plan for the expression. - * @param expression The expression to resolved. - * @return The resolved expression. - */ - private def resolveExpressionFrom(left: LogicalPlan, right: Option[LogicalPlan] = None) - (expression: Expression): Expression = { - // Fake a project for the expression based on the source plan. 
- val fakeProject = if (right.isDefined) { - Project(Seq(Alias(expression, "_c0")()), - sparkAdapter.getCatalystPlanUtils.createJoin(left, right.get, Inner)) - } else { - Project(Seq(Alias(expression, "_c0")()), - left) - } - // Resolve the fake project - val resolvedProject = - analyzer.ResolveReferences.apply(fakeProject).asInstanceOf[Project] - val unResolvedAttrs = resolvedProject.projectList.head.collect { - case attr: UnresolvedAttribute => attr - } - if (unResolvedAttrs.nonEmpty) { - throw new AnalysisException(s"Cannot resolve ${unResolvedAttrs.mkString(",")} in " + - s"${expression.sql}, the input " + s"columns is: [${fakeProject.child.output.mkString(", ")}]") - } - // Fetch the resolved expression from the fake project. - resolvedProject.projectList.head.asInstanceOf[Alias].child - } -} - /** * Rule for rewrite some spark commands to hudi's implementation. * @param sparkSession + * + * TODO merge w/ ResolveImplementations */ case class HoodiePostAnalysisRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = { diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodiePruneFileSourcePartitions.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodiePruneFileSourcePartitions.scala index 46cb931a59b6d..26ef2e0188c89 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodiePruneFileSourcePartitions.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodiePruneFileSourcePartitions.scala @@ -41,7 +41,7 @@ case class HoodiePruneFileSourcePartitions(spark: SparkSession) extends Rule[Log override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown { case op @ PhysicalOperation(projects, filters, lr @ LogicalRelation(HoodieRelationMatcher(fileIndex), _, _, _)) - if sparkAdapter.isHoodieTable(lr, spark) && !fileIndex.hasPredicatesPushedDown => + if !fileIndex.hasPredicatesPushedDown => val deterministicFilters = filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)) val normalizedFilters = exprUtils.normalizeExprs(deterministicFilters, lr.output) diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/DeleteHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/DeleteHoodieTableCommand.scala index e1dc8daa4ca97..004a102287671 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/DeleteHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/DeleteHoodieTableCommand.scala @@ -20,35 +20,44 @@ package org.apache.spark.sql.hudi.command import org.apache.hudi.SparkAdapterSupport import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable -import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._ +import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, Filter} import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.command.HoodieLeafRunnableCommand.stripMetaFieldAttributes -case class DeleteHoodieTableCommand(deleteTable: DeleteFromTable) extends HoodieLeafRunnableCommand - with SparkAdapterSupport with ProvidesHoodieConfig { +case class 
DeleteHoodieTableCommand(dft: DeleteFromTable) extends HoodieLeafRunnableCommand + with SparkAdapterSupport + with ProvidesHoodieConfig { - private val table = deleteTable.table + override def run(sparkSession: SparkSession): Seq[Row] = { + val catalogTable = sparkAdapter.resolveHoodieTable(dft.table) + .map(HoodieCatalogTable(sparkSession, _)) + .get - private val tableId = getTableIdentifier(table) + val tableId = catalogTable.table.qualifiedName - override def run(sparkSession: SparkSession): Seq[Row] = { - logInfo(s"start execute delete command for $tableId") + logInfo(s"Executing 'DELETE FROM' command for $tableId") + + val condition = sparkAdapter.extractDeleteCondition(dft) + + val targetLogicalPlan = stripMetaFieldAttributes(dft.table) + val filteredPlan = if (condition != null) { + Filter(condition, targetLogicalPlan) + } else { + targetLogicalPlan + } - // Remove meta fields from the data frame - var df = removeMetaFields(Dataset.ofRows(sparkSession, table)) - val condition = sparkAdapter.extractDeleteCondition(deleteTable) - if (condition != null) df = df.filter(Column(condition)) + val config = buildHoodieDeleteTableConfig(catalogTable, sparkSession) + val df = Dataset.ofRows(sparkSession, filteredPlan) - val hoodieCatalogTable = HoodieCatalogTable(sparkSession, tableId) - val config = buildHoodieDeleteTableConfig(hoodieCatalogTable, sparkSession) - df.write - .format("hudi") + df.write.format("hudi") .mode(SaveMode.Append) .options(config) .save() - sparkSession.catalog.refreshTable(tableId.unquotedString) - logInfo(s"Finish execute delete command for $tableId") + + sparkSession.catalog.refreshTable(tableId) + + logInfo(s"Finished executing 'DELETE FROM' command for $tableId") + Seq.empty[Row] } } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/IndexCommands.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/IndexCommands.scala index 8a3b5630b67bb..8ac0831a22f5a 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/IndexCommands.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/IndexCommands.scala @@ -25,6 +25,7 @@ import org.apache.hudi.HoodieConversionUtils.toScalaOption import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.hudi.secondary.index.SecondaryIndexManager import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.getTableLocation @@ -34,16 +35,16 @@ import java.util import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter} -case class CreateIndexCommand( - tableId: TableIdentifier, - indexName: String, - indexType: String, - ignoreIfExists: Boolean, - columns: Seq[(Attribute, Map[String, String])], - options: Map[String, String], - override val output: Seq[Attribute]) extends IndexBaseCommand { +case class CreateIndexCommand(table: CatalogTable, + indexName: String, + indexType: String, + ignoreIfExists: Boolean, + columns: Seq[(Attribute, Map[String, String])], + options: Map[String, String], + override val output: Seq[Attribute]) extends IndexBaseCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + val tableId = table.identifier val metaClient = createHoodieTableMetaClient(tableId, 
sparkSession) val columnsMap: java.util.LinkedHashMap[String, java.util.Map[String, String]] = new util.LinkedHashMap[String, java.util.Map[String, String]]() @@ -62,13 +63,13 @@ case class CreateIndexCommand( } } -case class DropIndexCommand( - tableId: TableIdentifier, - indexName: String, - ignoreIfNotExists: Boolean, - override val output: Seq[Attribute]) extends IndexBaseCommand { +case class DropIndexCommand(table: CatalogTable, + indexName: String, + ignoreIfNotExists: Boolean, + override val output: Seq[Attribute]) extends IndexBaseCommand { override def run(sparkSession: SparkSession): Seq[Row] = { + val tableId = table.identifier val metaClient = createHoodieTableMetaClient(tableId, sparkSession) SecondaryIndexManager.getInstance().drop(metaClient, indexName, ignoreIfNotExists) @@ -82,12 +83,11 @@ case class DropIndexCommand( } } -case class ShowIndexesCommand( - tableId: TableIdentifier, - override val output: Seq[Attribute]) extends IndexBaseCommand { +case class ShowIndexesCommand(table: CatalogTable, + override val output: Seq[Attribute]) extends IndexBaseCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - val metaClient = createHoodieTableMetaClient(tableId, sparkSession) + val metaClient = createHoodieTableMetaClient(table.identifier, sparkSession) val secondaryIndexes = SecondaryIndexManager.getInstance().show(metaClient) val mapper = getObjectMapper @@ -109,13 +109,12 @@ case class ShowIndexesCommand( } } -case class RefreshIndexCommand( - tableId: TableIdentifier, - indexName: String, - override val output: Seq[Attribute]) extends IndexBaseCommand { +case class RefreshIndexCommand(table: CatalogTable, + indexName: String, + override val output: Seq[Attribute]) extends IndexBaseCommand { override def run(sparkSession: SparkSession): Seq[Row] = { - val metaClient = createHoodieTableMetaClient(tableId, sparkSession) + val metaClient = createHoodieTableMetaClient(table.identifier, sparkSession) SecondaryIndexManager.getInstance().refresh(metaClient, indexName) Seq.empty } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala index f07611ad0198a..35f00ff95e6d3 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.hudi.ProvidesHoodieConfig import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql._ +import org.apache.spark.sql.hudi.command.HoodieLeafRunnableCommand.stripMetaFieldAttributes /** * Command for insert into Hudi table. 
@@ -139,7 +140,7 @@ object InsertIntoHoodieTableCommand extends Logging with ProvidesHoodieConfig wi val staticPartitionValues = filterStaticPartitionValues(partitionsSpec) // Make sure we strip out meta-fields from the incoming dataset (these will have to be discarded anyway) - val cleanedQuery = stripMetaFields(query) + val cleanedQuery = stripMetaFieldAttributes(query) // To validate and align properly output of the query, we simply filter out partition columns with already // provided static values from the table's schema // @@ -196,7 +197,7 @@ object InsertIntoHoodieTableCommand extends Logging with ProvidesHoodieConfig wi .filter(pf => staticPartitionValues.contains(pf.name)) .map(pf => { val staticPartitionValue = staticPartitionValues(pf.name) - val castExpr = castIfNeeded(Literal.create(staticPartitionValue), pf.dataType, conf) + val castExpr = castIfNeeded(Literal.create(staticPartitionValue), pf.dataType) Alias(castExpr, pf.name)() }) @@ -214,15 +215,6 @@ object InsertIntoHoodieTableCommand extends Logging with ProvidesHoodieConfig wi } } - def stripMetaFields(query: LogicalPlan): LogicalPlan = { - val filteredOutput = query.output.filterNot(attr => isMetaField(attr.name)) - if (filteredOutput == query.output) { - query - } else { - Project(filteredOutput, query) - } - } - private def filterStaticPartitionValues(partitionsSpec: Map[String, Option[String]]): Map[String, String] = partitionsSpec.filter(p => p._2.isDefined).mapValues(_.get) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala index 54f2534b4a07e..9c39d82c39cbd 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala @@ -28,97 +28,125 @@ import org.apache.hudi.config.HoodieWriteConfig.{AVRO_SCHEMA_VALIDATE_ENABLE, TB import org.apache.hudi.exception.HoodieException import org.apache.hudi.hive.HiveSyncConfigHolder import org.apache.hudi.sync.common.HoodieSyncConfig +import org.apache.hudi.util.JFunction.scalaFunction1Noop import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, SparkAdapterSupport} -import org.apache.spark.sql.HoodieCatalystExpressionUtils.MatchCast +import org.apache.spark.sql.HoodieCatalystExpressionUtils.{MatchCast, attributeEquals} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReference +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, EqualTo, Expression, Literal, NamedExpression, PredicateHelper} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._ -import org.apache.spark.sql.hudi.HoodieSqlUtils.getMergeIntoTargetTableId +import org.apache.spark.sql.hudi.analysis.HoodieAnalysis.failAnalysis import org.apache.spark.sql.hudi.ProvidesHoodieConfig.combineOptions -import 
org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.CoercedAttributeReference +import org.apache.spark.sql.hudi.command.MergeIntoHoodieTableCommand.{CoercedAttributeReference, encodeAsBase64String, stripCasting, toStructType} import org.apache.spark.sql.hudi.command.payload.ExpressionPayload import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._ import org.apache.spark.sql.hudi.ProvidesHoodieConfig -import org.apache.spark.sql.types.{BooleanType, StructType} +import org.apache.spark.sql.types.{BooleanType, StructField, StructType} import java.util.Base64 /** - * The Command for hoodie MergeIntoTable. - * The match on condition must contain the row key fields currently, so that we can use Hoodie - * Index to speed up the performance. + * Hudi's implementation of the {@code MERGE INTO} (MIT) Spark SQL statement. * - * The main algorithm: + * NOTE: That this implementation is restricted in a some aspects to accommodate for Hudi's crucial + * constraint (of requiring every record to bear unique primary-key): merging condition ([[mergeCondition]]) + * is currently can only (and must) reference target table's primary-key columns (this is necessary to + * leverage Hudi's upserting capabilities including Indexes) * - * We pushed down all the matched and not matched (condition, assignment) expression pairs to the - * ExpressionPayload. And the matched (condition, assignment) expression pairs will execute in the - * ExpressionPayload#combineAndGetUpdateValue to compute the result record, while the not matched - * expression pairs will execute in the ExpressionPayload#getInsertValue. + * Following algorithm is applied: * - * For Mor table, it is a litter complex than this. The matched record also goes through the getInsertValue - * and write append to the log. So the update actions & insert actions should process by the same - * way. We pushed all the update actions & insert actions together to the - * ExpressionPayload#getInsertValue. + *
+   * <ol>
+   *   <li>Incoming batch ([[sourceTable]]) is reshaped such that it bears, correspondingly:
+   *   a) the (required) "primary-key" column as well as b) the (optional) "pre-combine" column; this is
+   *   required since the MIT statement does not restrict the [[sourceTable]]'s schema to be aligned w/ the
+   *   [[targetTable]]'s one, while Hudi's upserting flow expects such columns to be present</li>
+   *
+   *   <li>After reshaping we're writing [[sourceTable]] as a normal batch using Hudi's upserting
+   *   sequence, where a special [[ExpressionPayload]] implementation of the [[HoodieRecordPayload]]
+   *   is used, allowing us to execute updating, deleting and inserting clauses like the following:
+   *
+   *   <ol>
+   *     <li>All the matched {@code WHEN MATCHED AND ... THEN (DELETE|UPDATE ...)} conditional clauses
+   *     will produce [[(condition, expression)]] tuples that will be executed w/in
+   *     [[ExpressionPayload#combineAndGetUpdateValue]] against existing (from [[targetTable]]) and
+   *     incoming (from [[sourceTable]]) records, producing the updated one;</li>
+   *
+   *     <li>Not matched {@code WHEN NOT MATCHED AND ... THEN INSERT ...} conditional clauses
+   *     will produce [[(condition, expression)]] tuples that will be executed w/in [[ExpressionPayload#getInsertValue]]
+   *     against incoming records, producing the ones to be inserted into the target table</li>
+   *   </ol>
+   *   </li>
+   * </ol>
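For orientation only (not part of this patch), a statement of the shape this command handles could look like the following minimal Scala sketch; the SparkSession `spark`, the Hudi table `target_tbl` and the source view `source_vw` are hypothetical names used purely for illustration.

```scala
// Illustrative MERGE INTO of the kind handled by MergeIntoHoodieTableCommand.
// Assumes an active SparkSession `spark`, a Hudi table `target_tbl` and a temp view `source_vw`
// (all names are hypothetical, not taken from this patch).
spark.sql(
  """
    |MERGE INTO target_tbl AS t
    |USING source_vw AS s
    |ON t.id = s.id
    |WHEN MATCHED AND s.op = 'D' THEN DELETE
    |WHEN MATCHED THEN UPDATE SET name = s.name, ts = s.ts
    |WHEN NOT MATCHED THEN INSERT *
    |""".stripMargin)
```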
+ * + * TODO explain workflow for MOR tables */ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends HoodieLeafRunnableCommand - with SparkAdapterSupport with ProvidesHoodieConfig { + with SparkAdapterSupport + with ProvidesHoodieConfig + with PredicateHelper { private var sparkSession: SparkSession = _ - /** - * The target table identify. - */ - private lazy val targetTableIdentify: TableIdentifier = getMergeIntoTargetTableId(mergeInto) - - /** - * The target table schema without hoodie meta fields. - */ - private var sourceDFOutput = mergeInto.sourceTable.output.filter(attr => !isMetaField(attr.name)) - /** * The target table schema without hoodie meta fields. */ - private lazy val targetTableSchemaWithoutMetaFields = + private lazy val targetTableSchema = removeMetaFields(mergeInto.targetTable.schema).fields - private lazy val hoodieCatalogTable = HoodieCatalogTable(sparkSession, targetTableIdentify) + private lazy val hoodieCatalogTable = sparkAdapter.resolveHoodieTable(mergeInto.targetTable) match { + case Some(catalogTable) => HoodieCatalogTable(sparkSession, catalogTable) + case _ => + failAnalysis(s"Failed to resolve MERGE INTO statement into the Hudi table. Got instead: ${mergeInto.targetTable}") + } private lazy val targetTableType = hoodieCatalogTable.tableTypeName /** + * Mapping of the Merge-Into-Table (MIT) command's [[targetTable]] attribute into + * corresponding expression (involving reference from the [[sourceTable]]) from the MIT + * [[mergeCondition]] condition. For ex, + *
MERGE INTO ... ON t.id = s.s_id AND t.name = lowercase(s.s_name)
+ * will produce + *
Map("id" -> "s_id", "name" -> lowercase("s_name")
+ * + * Such mapping is used to be able to properly merge the record in the incoming batch against + * existing table. Let's take following merge statement as an example: * - * Return a map of target key to the source expression from the Merge-On Condition. - * e.g. merge on t.id = s.s_id AND t.name = s.s_name, we return - * Map("id" -> "s_id", "name" ->"s_name") - * TODO Currently Non-equivalent conditions are not supported. + *
+   * MERGE INTO ... AS target USING ... AS source
+   * ON target.id = lowercase(source.id) ...
+   * 
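As a rough illustration of the mapping described above, the following hedged sketch shows how attribute-to-expression pairs could be pulled out of a conjunctive equality merge condition. It is a simplified stand-in, not the helper this patch actually uses (the real implementation additionally strips redundant casts and validates that the source-side expression can be up-cast to the target attribute's type).

```scala
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, EqualTo, Expression}

// Hypothetical sketch: derive target-attribute -> source-expression pairs from a merge condition
// such as `t.id = s.s_id AND t.name = lowercase(s.s_name)`.
def conditionMapping(cond: Expression, targetAttrs: Seq[Attribute]): Seq[(Attribute, Expression)] = {
  // Split `a AND b AND c` into its individual conjuncts
  def conjuncts(e: Expression): Seq[Expression] = e match {
    case And(left, right) => conjuncts(left) ++ conjuncts(right)
    case other            => Seq(other)
  }
  // Keep only equality predicates that reference a target-table attribute on one side
  conjuncts(cond).collect {
    case EqualTo(attr: Attribute, expr) if targetAttrs.exists(_.semanticEquals(attr)) => attr -> expr
    case EqualTo(expr, attr: Attribute) if targetAttrs.exists(_.semanticEquals(attr)) => attr -> expr
  }
}
```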
+ * + * To be able to leverage Hudi's engine to merge an incoming dataset against the existing table + * we will have to make sure that both [[source]] and [[target]] tables have the *same* + * "primary-key" and "pre-combine" columns. Since actual MIT condition might be leveraging an arbitrary + * expression involving [[source]] column(s), we will have to add "phony" column matching the + * primary-key one of the target table. */ - private lazy val targetKey2SourceExpression: Map[String, Expression] = { - val resolver = sparkSession.sessionState.conf.resolver - val conditions = splitByAnd(mergeInto.mergeCondition) - val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo]) - if (!allEqs) { - throw new IllegalArgumentException("Non-Equal condition is not support for Merge " + - s"Into Statement: ${mergeInto.mergeCondition.sql}") + private lazy val primaryKeyAttributeToConditionExpression: Seq[(Attribute, Expression)] = { + val conditions = splitConjunctivePredicates(mergeInto.mergeCondition) + if (!conditions.forall(p => p.isInstanceOf[EqualTo])) { + throw new AnalysisException(s"Currently only equality predicates are supported in MERGE INTO statement " + + s"(provided ${mergeInto.mergeCondition.sql}") } - val targetAttrs = mergeInto.targetTable.output - val cleanedConditions = conditions.map(_.asInstanceOf[EqualTo]).map { - // Here we're unraveling superfluous casting of expressions on both sides of the matched-on condition, - // in case both of them are casted to the same type (which might be result of either explicit casting - // from the user, or auto-casting performed by Spark for type coercion), which has potential - // potential of rendering the whole operation as invalid (check out HUDI-4861 for more details) - case EqualTo(MatchCast(leftExpr, leftCastTargetType, _, _), MatchCast(rightExpr, rightCastTargetType, _, _)) - if leftCastTargetType.sameType(rightCastTargetType) => EqualTo(leftExpr, rightExpr) + val resolver = sparkSession.sessionState.analyzer.resolver + val primaryKeyField = hoodieCatalogTable.tableConfig.getRecordKeyFieldProp - case c => c - } + val targetAttrs = mergeInto.targetTable.outputSet val exprUtils = sparkAdapter.getCatalystExpressionUtils + // Here we're unraveling superfluous casting of expressions on both sides of the matched-on condition, + // in case both of them are casted to the same type (which might be result of either explicit casting + // from the user, or auto-casting performed by Spark for type coercion), which has potential + // of rendering the whole operation as invalid. This is the case b/c we're leveraging Hudi's internal + // flow of matching records and therefore will be matching source and target table's primary-key values + // as they are w/o the ability of transforming them w/ custom expressions (unlike in vanilla Spark flow). + // + // Check out HUDI-4861 for more details + val cleanedConditions = conditions.map(_.asInstanceOf[EqualTo]).map(stripCasting) + // Expressions of the following forms are supported: // `target.id = ` (or ` = target.id`) // `cast(target.id, ...) 
= ` (or ` = cast(target.id, ...)`) @@ -127,164 +155,189 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie // target table side (since we're gonna be matching against primary-key column as is) expression // on the opposite side of the comparison should be cast-able to the primary-key column's data-type // t/h "up-cast" (ie w/o any loss in precision) - val target2Source = cleanedConditions.map { - case EqualTo(CoercedAttributeReference(attr), expr) - if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) => - if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { - targetAttrs.find(f => resolver(f.name, attr.name)).get.name -> - castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf) - } else { - throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " - + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") - } + val targetAttr2ConditionExpressions = cleanedConditions.map { + case EqualTo(CoercedAttributeReference(attr), expr) if targetAttrs.exists(f => attributeEquals(f, attr)) => + if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { + // NOTE: It's critical we reference output attribute here and not the one from condition + val targetAttr = targetAttrs.find(f => attributeEquals(f, attr)).get + targetAttr -> castIfNeeded(expr, attr.dataType) + } else { + throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " + + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") + } - case EqualTo(expr, CoercedAttributeReference(attr)) - if targetAttrs.exists(f => attributeEqual(f, attr, resolver)) => - if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { - targetAttrs.find(f => resolver(f.name, attr.name)).get.name -> - castIfNeeded(expr, attr.dataType, sparkSession.sqlContext.conf) - } else { - throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " - + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") - } + case EqualTo(expr, CoercedAttributeReference(attr)) if targetAttrs.exists(f => attributeEquals(f, attr)) => + if (exprUtils.canUpCast(expr.dataType, attr.dataType)) { + // NOTE: It's critical we reference output attribute here and not the one from condition + val targetAttr = targetAttrs.find(f => attributeEquals(f, attr)).get + targetAttr -> castIfNeeded(expr, attr.dataType) + } else { + throw new AnalysisException(s"Invalid MERGE INTO matching condition: ${expr.sql}: " + + s"can't cast ${expr.sql} (of ${expr.dataType}) to ${attr.dataType}") + } case expr => throw new AnalysisException(s"Invalid MERGE INTO matching condition: `${expr.sql}`: " + "expected condition should be 'target.id = ', e.g. " + "`t.id = s.id` or `t.id = cast(s.id, ...)`") - }.toMap + } - target2Source + targetAttr2ConditionExpressions.collect { + case (attr, expr) if resolver(attr.name, primaryKeyField) => + // NOTE: Here we validate that condition expression involving primary-key column(s) is a simple + // attribute-reference expression (possibly wrapped into a cast). This is necessary to disallow + // statements like following + // + // MERGE INTO ... AS t USING ( + // SELECT ... FROM ... AS s + // ) + // ON t.id = s.id + 1 + // WHEN MATCHED THEN UPDATE * + // + // Which (in the current design) could result in a primary key of the record being modified, + // which is not allowed. 
+ if (!resolvesToSourceAttribute(expr)) { + throw new AnalysisException("Only simple conditions of the form `t.id = s.id` are allowed on the " + + s"primary-key column. Found `${attr.sql} = ${expr.sql}`") + } + + (attr, expr) + } } /** - * Get the mapping of target preCombineField to the source expression. + * Please check description for [[primaryKeyAttributeToConditionExpression]] */ - private lazy val target2SourcePreCombineFiled: Option[(String, Expression)] = { - val updateActions = mergeInto.matchedActions.collect { case u: UpdateAction => u } - assert(updateActions.size <= 1, s"Only support one updateAction currently, current update action count is: ${updateActions.size}") - - val updateAction = updateActions.headOption - hoodieCatalogTable.preCombineKey.map(preCombineField => { - val sourcePreCombineField = - updateAction.map(u => u.assignments.filter { - case Assignment(key: AttributeReference, _) => key.name.equalsIgnoreCase(preCombineField) - case _=> false - }.head.value - ).getOrElse { - // If there is no update action, mapping the target column to the source by order. - val target2Source = mergeInto.targetTable.output - .filter(attr => !isMetaField(attr.name)) - .map(_.name) - .zip(mergeInto.sourceTable.output.filter(attr => !isMetaField(attr.name))) - .toMap - target2Source.getOrElse(preCombineField, null) + private lazy val preCombineAttributeAssociatedExpression: Option[(Attribute, Expression)] = { + val resolver = sparkSession.sessionState.analyzer.resolver + hoodieCatalogTable.preCombineKey.map { preCombineField => + val targetPreCombineAttribute = + mergeInto.targetTable.output + .find { attr => resolver(attr.name, preCombineField) } + .get + + // To find corresponding "pre-combine" attribute w/in the [[sourceTable]] we do + // - Check if we can resolve the attribute w/in the source table as is; if unsuccessful, then + // - Check if in any of the update actions, right-hand side of the assignment actually resolves + // to it, in which case we will determine left-hand side expression as the value of "pre-combine" + // attribute w/in the [[sourceTable]] + val sourceExpr = { + mergeInto.sourceTable.output.find(attr => resolver(attr.name, preCombineField)) match { + case Some(attr) => attr + case None => + updatingActions.flatMap(_.assignments).collectFirst { + case Assignment(attr: AttributeReference, expr) + if resolver(attr.name, preCombineField) && resolvesToSourceAttribute(expr) => expr + } getOrElse { + throw new AnalysisException(s"Failed to resolve pre-combine field `${preCombineField}` w/in the source-table output") + } + } - (preCombineField, sourcePreCombineField) - }).filter(p => p._2 != null) + } + + (targetPreCombineAttribute, sourceExpr) + } } override def run(sparkSession: SparkSession): Seq[Row] = { this.sparkSession = sparkSession + // TODO move to analysis phase + validate(mergeInto) + val sourceDF: DataFrame = sourceDataset // Create the write parameters - val parameters = buildMergeIntoConfig(hoodieCatalogTable) - // TODO Remove it when we implement ExpressionPayload for SparkRecord - val parametersWithAvroRecordMerger = parameters ++ Map(HoodieWriteConfig.RECORD_MERGER_IMPLS.key -> classOf[HoodieAvroRecordMerger].getName) - executeUpsert(sourceDF, parametersWithAvroRecordMerger) + val props = buildMergeIntoConfig(hoodieCatalogTable) + // Do the upsert + executeUpsert(sourceDF, props) + // Refresh the table in the catalog + sparkSession.catalog.refreshTable(hoodieCatalogTable.table.qualifiedName) - 
sparkSession.catalog.refreshTable(targetTableIdentify.unquotedString) Seq.empty[Row] } + private val updatingActions: Seq[UpdateAction] = mergeInto.matchedActions.collect { case u: UpdateAction => u} + private val insertingActions: Seq[InsertAction] = mergeInto.notMatchedActions.collect { case u: InsertAction => u} + private val deletingActions: Seq[DeleteAction] = mergeInto.matchedActions.collect { case u: DeleteAction => u} + /** - * Build the sourceDF. We will append the source primary key expressions and - * preCombine field expression to the sourceDF. - * e.g. - *

- * merge into h0 - * using (select 1 as id, 'a1' as name, 1000 as ts) s0 - * on h0.id = s0.id + 1 - * when matched then update set id = s0.id, name = s0.name, ts = s0.ts + 1 - *

- * "ts" is the pre-combine field of h0. + * Here we're adjusting incoming (source) dataset in case its schema is divergent from + * the target table, to make sure it (at a bare minimum) * - * The targetKey2SourceExpression is: ("id", "s0.id + 1"). - * The target2SourcePreCombineFiled is:("ts", "s0.ts + 1"). - * We will append the "s0.id + 1 as id" and "s0.ts + 1 as ts" to the sourceDF to compute the - * row key and pre-combine field. + *
+   * <ol>
+   *   <li>Contains "primary-key" column (as defined by target table's config)</li>
+   *   <li>Contains "pre-combine" column (as defined by target table's config, if any)</li>
+   * </ol>
* - */ - private lazy val sourceDF: DataFrame = { - var sourceDF = Dataset.ofRows(sparkSession, mergeInto.sourceTable) - targetKey2SourceExpression.foreach { - case (targetColumn, sourceExpression) - if !containsPrimaryKeyFieldReference(targetColumn, sourceExpression) => - sourceDF = sourceDF.withColumn(targetColumn, new Column(sourceExpression)) - sourceDFOutput = sourceDFOutput :+ AttributeReference(targetColumn, sourceExpression.dataType)() - case _=> - } - target2SourcePreCombineFiled.foreach { - case (targetPreCombineField, sourceExpression) - if !containsPreCombineFieldReference(targetPreCombineField, sourceExpression) => - sourceDF = sourceDF.withColumn(targetPreCombineField, new Column(sourceExpression)) - sourceDFOutput = sourceDFOutput :+ AttributeReference(targetPreCombineField, sourceExpression.dataType)() - case _=> - } - sourceDF - } - - /** - * Check whether the source expression has the same column name with target column. + * In cases when [[sourceTable]] doesn't contain aforementioned columns, following heuristic + * will be applied: + * + *
+   * <ul>
+   *   <li>The expression for the "primary-key" column is extracted from the merge-on condition of the
+   *   MIT statement: Hudi's implementation of the statement restricts the kind of merge-on condition
+   *   permitted to only those referencing the primary-key column(s) of the target table; as such we're
+   *   leveraging the matching side of such a conditional expression (the one containing the [[sourceTable]] attribute),
+   *   interpreting it as the primary-key column in the [[sourceTable]]</li>
+   *
+   *   <li>The expression for the "pre-combine" column (optional) is extracted from the matching update
+   *   clause ({@code WHEN MATCHED ... THEN UPDATE ...}) as the right-hand side of the assignment referencing
+   *   the pre-combine attribute of the target table</li>
+   * </ul>
      + * + * For example, w/ the following statement (primary-key column is [[id]], while pre-combine column is [[ts]]) + *
      +   *    MERGE INTO target
      +   *    USING (SELECT 1 AS sid, 'A1' AS sname, 1000 AS sts) source
      +   *    ON target.id = source.sid
      +   *    WHEN MATCHED THEN UPDATE SET id = source.sid, name = source.sname, ts = source.sts
      +   * 
      * - * Merge condition cases that return true: - * 1) merge into .. on h0.id = s0.id .. - * 2) merge into .. on h0.id = cast(s0.id as int) .. - * "id" is primaryKey field of h0. + * We will append following columns to the source dataset: + *
+   * <ul>
+   *   <li>{@code id = source.sid}</li>
+   *   <li>{@code ts = source.sts}</li>
+   * </ul>
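A rough DataFrame-level equivalent of this heuristic, using the hypothetical `sid`/`sname`/`sts` source from the example above; note that the command itself amends the logical plan with `Alias` expressions rather than going through the DataFrame API, so this is only an approximation of the effect.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Approximate effect of the source-amending heuristic (illustrative only).
val spark = SparkSession.builder().master("local[1]").appName("mit-source-amend").getOrCreate()
val source = spark.sql("SELECT 1 AS sid, 'A1' AS sname, 1000 AS sts")
val amended = source
  .withColumn("id", col("sid"))  // primary-key column, derived from `ON target.id = source.sid`
  .withColumn("ts", col("sts"))  // pre-combine column, derived from `UPDATE SET ... ts = source.sts`
amended.printSchema()
```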
      */ - private def containsPrimaryKeyFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = { - val sourceColumnNames = sourceDFOutput.map(_.name) - val resolver = sparkSession.sessionState.conf.resolver - - sourceExpression match { - case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true - // SPARK-35857: the definition of Cast has been changed in Spark3.2. - // Match the class type instead of call the `unapply` method. - case cast: Cast => - cast.child match { - case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true - case _ => false - } - case _=> false + def sourceDataset: DataFrame = { + val resolver = sparkSession.sessionState.analyzer.resolver + + val sourceTablePlan = mergeInto.sourceTable + val sourceTableOutput = sourceTablePlan.output + + val requiredAttributesMap = primaryKeyAttributeToConditionExpression ++ preCombineAttributeAssociatedExpression + + val (existingAttributesMap, missingAttributesMap) = requiredAttributesMap.partition { + case (keyAttr, _) => sourceTableOutput.exists(attr => resolver(keyAttr.name, attr.name)) } - } - /** - * Check whether the source expression on preCombine field contains the same column name with target column. - * - * Merge expression cases that return true: - * 1) merge into .. on .. update set ts = s0.ts - * 2) merge into .. on .. update set ts = cast(s0.ts as int) - * 3) merge into .. on .. update set ts = s0.ts+1 (expressions like this whose sub node has the same column name with target) - * "ts" is preCombine field of h0. - */ - private def containsPreCombineFieldReference(targetColumnName: String, sourceExpression: Expression): Boolean = { - val sourceColumnNames = sourceDFOutput.map(_.name) - val resolver = sparkSession.sessionState.conf.resolver + // NOTE: Primary key attribute (required) as well as Pre-combine one (optional) defined + // in the [[targetTable]] schema has to be present in the incoming [[sourceTable]] dataset. + // In cases when [[sourceTable]] doesn't bear such attributes (which, for ex, could happen + // in case of it having different schema), we will be adding additional columns (while setting + // them according to aforementioned heuristic) to meet Hudi's requirements + val additionalColumns: Seq[NamedExpression] = + missingAttributesMap.flatMap { + case (keyAttr, sourceExpression) if !sourceTableOutput.exists(attr => resolver(attr.name, keyAttr.name)) => + Seq(Alias(sourceExpression, keyAttr.name)()) + + case _ => Seq() + } - // sub node of the expression may have same column name with target column name - sourceExpression.find { - case attr: AttributeReference => sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) - case _ => false - }.isDefined - } + // In case when we're not adding new columns we need to make sure that the casing of the key attributes' + // matches to that one of the target table. 
This is necessary b/c unlike Spark, Avro is case-sensitive + // and therefore would fail downstream if case of corresponding columns don't match + val existingAttributes = existingAttributesMap.map(_._1) + val adjustedSourceTableOutput = sourceTableOutput.map { attr => + existingAttributes.find(keyAttr => resolver(keyAttr.name, attr.name)) match { + // To align the casing we just rename the attribute to match that one of the + // target table + case Some(keyAttr) => attr.withName(keyAttr.name) + case _ => attr + } + } - /** - * Compare a [[Attribute]] to another, return true if they have the same column name(by resolver) and exprId - */ - private def attributeEqual( - attr: Attribute, other: Attribute, resolver: Resolver): Boolean = { - resolver(attr.name, other.name) && attr.exprId == other.exprId + val amendedPlan = Project(adjustedSourceTableOutput ++ additionalColumns, sourceTablePlan) + + Dataset.ofRows(sparkSession, amendedPlan) } /** @@ -308,190 +361,196 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie (HoodieWriteConfig.WRITE_SCHEMA_OVERRIDE.key -> getTableSchema.toString) + (DataSourceWriteOptions.TABLE_TYPE.key -> targetTableType) - val updateActions = mergeInto.matchedActions.filter(_.isInstanceOf[UpdateAction]) - .map(_.asInstanceOf[UpdateAction]) - // Check for the update actions - checkUpdateAssignments(updateActions) - - val deleteActions = mergeInto.matchedActions.filter(_.isInstanceOf[DeleteAction]) - .map(_.asInstanceOf[DeleteAction]) - assert(deleteActions.size <= 1, "Should be only one delete action in the merge into statement.") - val deleteAction = deleteActions.headOption - - // Map of Condition -> Assignments - val updateConditionToAssignments = - updateActions.map(update => { - val rewriteCondition = update.condition.map(replaceAttributeInExpression) - .getOrElse(Literal.create(true, BooleanType)) - val formatAssignments = rewriteAndReOrderAssignments(update.assignments) - rewriteCondition -> formatAssignments - }).toMap - // Serialize the Map[UpdateCondition, UpdateAssignments] to base64 string - val serializedUpdateConditionAndExpressions = Base64.getEncoder - .encodeToString(Serializer.toBytes(updateConditionToAssignments)) - writeParams += (PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS -> - serializedUpdateConditionAndExpressions) - - if (deleteAction.isDefined) { - val deleteCondition = deleteAction.get.condition - .map(replaceAttributeInExpression) - .getOrElse(Literal.create(true, BooleanType)) - // Serialize the Map[DeleteCondition, empty] to base64 string - val serializedDeleteCondition = Base64.getEncoder - .encodeToString(Serializer.toBytes(Map(deleteCondition -> Seq.empty[Assignment]))) - writeParams += (PAYLOAD_DELETE_CONDITION -> serializedDeleteCondition) - } - - val insertActions = - mergeInto.notMatchedActions.map(_.asInstanceOf[InsertAction]) - - // Check for the insert actions - checkInsertAssignments(insertActions) - - // Serialize the Map[InsertCondition, InsertAssignments] to base64 string - writeParams += (PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS -> - serializedInsertConditionAndExpressions(insertActions)) - - // Remove the meta fields from the sourceDF as we do not need these when writing. 
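// --- Illustrative sketch (not part of the patch) ---------------------------------------
// A minimal sketch of the column-appending heuristic described above, using the plain
// DataFrame API instead of the command's Catalyst plumbing. It assumes an existing
// SparkSession `spark` and a hypothetical source whose `sid`/`sts` columns are mapped to
// the target's `id` (primary key) and `ts` (pre-combine) via the merge condition and
// assignment clauses.
import org.apache.spark.sql.functions.col

val source = spark.sql("SELECT 1 AS sid, 'A1' AS sname, 1000 AS sts")
// `ON target.id = source.sid` and `... SET ts = source.sts` let us derive the missing
// key / pre-combine columns and append them to the source dataset
val amendedSource = source
  .withColumn("id", col("sid"))
  .withColumn("ts", col("sts"))
amendedSource.printSchema()   // sid, sname, sts, plus the appended id, ts
// ----------------------------------------------------------------------------------------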
- val trimmedSourceDF = removeMetaFields(sourceDF) + writeParams ++= Seq( + // Append (encoded) updating actions + PAYLOAD_UPDATE_CONDITION_AND_ASSIGNMENTS -> + // NOTE: For updating clause we allow partial assignments, where only some of the fields of the target + // table's records are updated (w/ the missing ones keeping their existing values) + serializeConditionalAssignments(updatingActions.map(a => (a.condition, a.assignments)), + allowPartialAssignments = true), + // Append (encoded) inserting actions + PAYLOAD_INSERT_CONDITION_AND_ASSIGNMENTS -> + serializeConditionalAssignments(insertingActions.map(a => (a.condition, a.assignments)), + validator = validateInsertingAssignmentExpression) + ) - // Supply original record's Avro schema to provided to [[ExpressionPayload]] - writeParams += (PAYLOAD_RECORD_AVRO_SCHEMA -> - convertStructTypeToAvroSchema(trimmedSourceDF.schema, "record", "").toString) + // Append (encoded) deleting actions + writeParams ++= deletingActions.headOption.map { + case DeleteAction(condition) => + PAYLOAD_DELETE_CONDITION -> serializeConditionalAssignments(Seq(condition -> Seq.empty)) + }.toSeq + + // Append + // - Original [[sourceTable]] (Avro) schema + // - Schema of the expected "joined" output of the [[sourceTable]] and [[targetTable]] + writeParams ++= Seq( + PAYLOAD_RECORD_AVRO_SCHEMA -> + convertStructTypeToAvroSchema(sourceDF.schema, "record", "").toString, + PAYLOAD_EXPECTED_COMBINED_SCHEMA -> encodeAsBase64String(toStructType(joinedExpectedOutput)) + ) - val (success, _, _, _, _, _) = HoodieSparkSqlWriter.write(sparkSession.sqlContext, SaveMode.Append, writeParams, trimmedSourceDF) + val (success, _, _, _, _, _) = HoodieSparkSqlWriter.write(sparkSession.sqlContext, SaveMode.Append, writeParams, sourceDF) if (!success) { throw new HoodieException("Merge into Hoodie table command failed") } } - private def checkUpdateAssignments(updateActions: Seq[UpdateAction]): Unit = { - updateActions.foreach(update => - assert(update.assignments.length == targetTableSchemaWithoutMetaFields.length, - s"The number of update assignments[${update.assignments.length}] must equal to the " + - s"targetTable field size[${targetTableSchemaWithoutMetaFields.length}]")) - // For MOR table, the target table field cannot be the right-value in the update action. 
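// --- Illustrative sketch (not part of the patch) ---------------------------------------
// Round-trips a serializable value through a Base64 string, the same shape of encoding
// used above to ship conditional assignments to the payload class via write config.
// Plain Java serialization is used here purely for illustration; the command itself
// relies on Hudi's own Serializer utility.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.Base64

def encodeAsBase64(obj: AnyRef): String = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(obj)
  out.close()
  Base64.getEncoder.encodeToString(buffer.toByteArray)
}

def decodeFromBase64[T](encoded: String): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(Base64.getDecoder.decode(encoded)))
  try in.readObject().asInstanceOf[T] finally in.close()
}

val encoded = encodeAsBase64(Map("condition" -> "true"))
assert(decodeFromBase64[Map[String, String]](encoded)("condition") == "true")
// ----------------------------------------------------------------------------------------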
- if (targetTableType == MOR_TABLE_TYPE_OPT_VAL) { - updateActions.foreach(update => { - val targetAttrs = update.assignments.flatMap(a => a.value.collect { - case attr: AttributeReference if mergeInto.targetTable.outputSet.contains(attr) => attr - }) - assert(targetAttrs.isEmpty, - s"Target table's field(${targetAttrs.map(_.name).mkString(",")}) cannot be the right-value of the update clause for MOR table.") - }) - } - } - - private def checkInsertAssignments(insertActions: Seq[InsertAction]): Unit = { - insertActions.foreach(insert => - assert(insert.assignments.length == targetTableSchemaWithoutMetaFields.length, - s"The number of insert assignments[${insert.assignments.length}] must equal to the " + - s"targetTable field size[${targetTableSchemaWithoutMetaFields.length}]")) - - } - private def getTableSchema: Schema = { val (structName, nameSpace) = AvroConversionUtils - .getAvroRecordNameAndNamespace(targetTableIdentify.identifier) + .getAvroRecordNameAndNamespace(hoodieCatalogTable.tableName) AvroConversionUtils.convertStructTypeToAvroSchema( - new StructType(targetTableSchemaWithoutMetaFields), structName, nameSpace) + new StructType(targetTableSchema), structName, nameSpace) } /** - * Serialize the Map[InsertCondition, InsertAssignments] to base64 string. - * @param insertActions - * @return + * Binds and serializes sequence of [[(Expression, Seq[Expression])]] where + *
+   * <ul>
+   *   <li>First [[Expression]] designates condition (in update/insert clause)</li>
+   *   <li>Second [[Seq[Expression] ]] designates individual column assignments (in update/insert clause)</li>
+   * </ul>
+   *
+   * Such that
+   * <ol>
+   *   <li>All expressions are bound against expected payload layout (and ready to be code-gen'd)</li>
+   *   <li>Serialized into Base64 string to be subsequently passed to [[ExpressionPayload]]</li>
+   * </ol>
      */ - private def serializedInsertConditionAndExpressions(insertActions: Seq[InsertAction]): String = { - val insertConditionAndAssignments = - insertActions.map(insert => { - val rewriteCondition = insert.condition.map(replaceAttributeInExpression) - .getOrElse(Literal.create(true, BooleanType)) - val formatAssignments = rewriteAndReOrderAssignments(insert.assignments) - // Do the check for the insert assignments - checkInsertExpression(formatAssignments) - - rewriteCondition -> formatAssignments - }).toMap - Base64.getEncoder.encodeToString( - Serializer.toBytes(insertConditionAndAssignments)) + private def serializeConditionalAssignments(conditionalAssignments: Seq[(Option[Expression], Seq[Assignment])], + allowPartialAssignments: Boolean = false, + validator: Expression => Unit = scalaFunction1Noop): String = { + val boundConditionalAssignments = + conditionalAssignments.map { + case (condition, assignments) => + val boundCondition = condition.map(bindReferences).getOrElse(Literal.create(true, BooleanType)) + // NOTE: For deleting actions there's no assignments provided and no re-ordering is required. + // All other actions are expected to provide assignments correspondent to every field + // of the [[targetTable]] being assigned + val reorderedAssignments = if (assignments.nonEmpty) { + alignAssignments(assignments, allowPartialAssignments) + } else { + Seq.empty + } + // NOTE: We need to re-order assignments to follow the ordering of the attributes + // of the target table, such that the resulting output produced after execution + // of these expressions could be inserted into the target table as is + val boundAssignmentExprs = reorderedAssignments.map { + case Assignment(attr: Attribute, value) => + val boundExpr = bindReferences(value) + validator(boundExpr) + // Alias resulting expression w/ target table's expected column name, as well as + // do casting if necessary + Alias(castIfNeeded(boundExpr, attr.dataType), attr.name)() + } + + boundCondition -> boundAssignmentExprs + }.toMap + + encodeAsBase64String(boundConditionalAssignments) } /** - * Rewrite and ReOrder the assignments. - * The Rewrite is to replace the AttributeReference to BoundReference. - * The ReOrder is to make the assignments's order same with the target table. 
- * @param assignments - * @return + * Re-orders assignment expressions to adhere to the ordering of that of [[targetTable]] */ - private def rewriteAndReOrderAssignments(assignments: Seq[Expression]): Seq[Expression] = { - val attr2Assignment = assignments.map { - case Assignment(attr: AttributeReference, value) => { - val rewriteValue = replaceAttributeInExpression(value) - attr -> Alias(rewriteValue, attr.name)() - } - case assignment => throw new IllegalArgumentException(s"Illegal Assignment: ${assignment.sql}") - }.toMap[Attribute, Expression] - // reorder the assignments by the target table field + private def alignAssignments(assignments: Seq[Assignment], allowPartialAssignments: Boolean): Seq[Assignment] = { + val attr2Assignments = assignments.map { + case assign @ Assignment(attr: Attribute, _) => attr -> assign + case a => + throw new AnalysisException(s"Only assignments of the form `t.field = ...` are supported at the moment (provided: `${a.sql}`)") + } + + // Reorder the assignments to follow the ordering of the target table mergeInto.targetTable.output .filterNot(attr => isMetaField(attr.name)) - .map(attr => { - val assignment = attr2Assignment.find(f => attributeEqual(f._1, attr, sparkSession.sessionState.conf.resolver)) - .getOrElse(throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}")) - castIfNeeded(assignment._2, attr.dataType, sparkSession.sqlContext.conf) - }) + .map { attr => + attr2Assignments.find(tuple => attributeEquals(tuple._1, attr)) match { + case Some((_, assignment)) => assignment + case None => + // In case partial assignments are allowed and there's no corresponding conditional assignment, + // create a self-assignment for the target table's attribute + if (allowPartialAssignments) { + Assignment(attr, attr) + } else { + throw new AnalysisException(s"Assignment expressions have to assign every attribute of target table " + + s"(provided: `${assignments.map(_.sql).mkString(",")}`") + } + } + } } /** - * Replace the AttributeReference to BoundReference. This is for the convenience of CodeGen - * in ExpressionCodeGen which use the field index to generate the code. So we must replace - * the AttributeReference to BoundReference here. - * @param exp - * @return + * Binds existing [[AttributeReference]]s (converting them into [[BoundReference]]s) against + * expected combined payload of + * + *
+   * <ol>
+   *   <li>Source table record, joined w/</li>
+   *   <li>Target table record</li>
+   * </ol>
      + * + * NOTE: PLEASE READ CAREFULLY BEFORE CHANGING + * This has to be in sync w/ [[ExpressionPayload]] that is actually performing comnbining of the + * records producing final payload being persisted. + * + * Joining is necessary to handle the case of the records being _updated_ (when record is present in + * both target and the source tables), since MIT statement allows resulting record to be + * an amalgamation of both existing and incoming records (for ex, partially updated). + * + * For newly inserted records, since no prior record exist in the target table, we're only going to + * use source payload to produce the resulting record -- hence, source dataset output is the left + * prefix of this join. + * + * Binding is necessary for [[ExpressionPayload]] to use the code-gen to effectively perform + * handling of the records (combining updated records, as well as producing new records to be inserted) */ - private def replaceAttributeInExpression(exp: Expression): Expression = { - val sourceJoinTargetFields = sourceDFOutput ++ - mergeInto.targetTable.output.filterNot(attr => isMetaField(attr.name)) - - exp transform { - case attr: AttributeReference => - val index = sourceJoinTargetFields.indexWhere(p => p.semanticEquals(attr)) - if (index == -1) { - throw new IllegalArgumentException(s"cannot find ${attr.qualifiedName} in source or " + - s"target at the merge into statement") - } - BoundReference(index, attr.dataType, attr.nullable) - case other => other - } + private def bindReferences(expr: Expression): Expression = { + // NOTE: Since original source dataset could be augmented w/ additional columns (please + // check its corresponding java-doc for more details) we have to get up-to-date list + // of its output attributes + val joinedExpectedOutputAttributes = joinedExpectedOutput + + bindReference(expr, joinedExpectedOutputAttributes, allowFailures = false) } /** - * Check the insert action expression. - * The insert expression should not contain target table field. + * Output of the expected (left) join of the a) [[sourceTable]] dataset (potentially amended w/ primary-key, + * pre-combine columns) with b) existing [[targetTable]] */ - private def checkInsertExpression(expressions: Seq[Expression]): Unit = { - expressions.foreach(exp => { - val references = exp.collect { - case reference: BoundReference => reference - } - references.foreach(ref => { - if (ref.ordinal >= sourceDFOutput.size) { - val targetColumn = targetTableSchemaWithoutMetaFields(ref.ordinal - sourceDFOutput.size) - throw new IllegalArgumentException(s"Insert clause cannot contain target table's field: ${targetColumn.name}" + - s" in ${exp.sql}") + private def joinedExpectedOutput: Seq[Attribute] = { + // NOTE: We're relying on [[sourceDataset]] here instead of [[mergeInto.sourceTable]], + // as it could be amended to add missing primary-key and/or pre-combine columns. 
+ // Please check [[sourceDataset]] scala-doc for more details + sourceDataset.queryExecution.analyzed.output ++ mergeInto.targetTable.output + } + + private def resolvesToSourceAttribute(expr: Expression): Boolean = { + val sourceTableOutputSet = mergeInto.sourceTable.outputSet + expr match { + case attr: AttributeReference => sourceTableOutputSet.contains(attr) + case MatchCast(attr: AttributeReference, _, _, _) => sourceTableOutputSet.contains(attr) + + case _ => false + } + } + + private def validateInsertingAssignmentExpression(expr: Expression): Unit = { + val sourceTableOutput = mergeInto.sourceTable.output + expr.collect { case br: BoundReference => br } + .foreach(br => { + if (br.ordinal >= sourceTableOutput.length) { + throw new AnalysisException(s"Expressions in insert clause of the MERGE INTO statement can only reference " + + s"source table attributes (ordinal ${br.ordinal}, total attributes in the source table ${sourceTableOutput.length})") } }) - }) } /** * Create the config for hoodie writer. */ private def buildMergeIntoConfig(hoodieCatalogTable: HoodieCatalogTable): Map[String, String] = { - - val targetTableDb = targetTableIdentify.database.getOrElse("default") - val targetTableName = targetTableIdentify.identifier + val tableId = hoodieCatalogTable.table.identifier + val targetTableDb = tableId.database.getOrElse("default") + val targetTableName = tableId.identifier val path = hoodieCatalogTable.tableLocation + val catalogProperties = hoodieCatalogTable.catalogProperties val tableConfig = hoodieCatalogTable.tableConfig val tableSchema = hoodieCatalogTable.tableSchema val partitionColumns = tableConfig.getPartitionFieldProp.split(",").map(_.toLowerCase) @@ -524,8 +583,9 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie HoodieSyncConfig.META_SYNC_PARTITION_EXTRACTOR_CLASS.key -> hiveSyncConfig.getString(HoodieSyncConfig.META_SYNC_PARTITION_EXTRACTOR_CLASS), SqlKeyGenerator.PARTITION_SCHEMA -> partitionSchema.toDDL, PAYLOAD_CLASS_NAME.key -> classOf[ExpressionPayload].getCanonicalName, + RECORD_MERGER_IMPLS.key -> classOf[HoodieAvroRecordMerger].getName, - // NOTE: We have to explicitly override following configs to make sure no schema validation is performed + // NOTE: We have to explicitly override following configs to make sure no schema validation is performed // as schema of the incoming dataset might be diverging from the table's schema (full schemas' // compatibility b/w table's schema and incoming one is not necessary in this case since we can // be cherry-picking only selected columns from the incoming dataset to be inserted/updated in the @@ -539,6 +599,49 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Hoodie combineOptions(hoodieCatalogTable, tableConfig, sparkSession.sqlContext.conf, defaultOpts = Map.empty, overridingOpts = overridingOpts) } + + + def validate(mit: MergeIntoTable): Unit = { + checkUpdatingActions(updatingActions) + checkInsertingActions(insertingActions) + checkDeletingActions(deletingActions) + } + + private def checkDeletingActions(deletingActions: Seq[DeleteAction]): Unit = { + if (deletingActions.length > 1) { + throw new AnalysisException(s"Only one deleting action is supported in MERGE INTO statement (provided ${deletingActions.length})") + } + } + + private def checkInsertingActions(insertActions: Seq[InsertAction]): Unit = { + insertActions.foreach(insert => + assert(insert.assignments.length == targetTableSchema.length, + s"The number of insert 
assignments[${insert.assignments.length}] must equal to the " + + s"targetTable field size[${targetTableSchema.length}]")) + + } + + private def checkUpdatingActions(updateActions: Seq[UpdateAction]): Unit = { + if (updateActions.length > 1) { + throw new AnalysisException(s"Only one updating action is supported in MERGE INTO statement (provided ${updateActions.length})") + } + + //updateActions.foreach(update => + // assert(update.assignments.length == targetTableSchema.length, + // s"The number of update assignments[${update.assignments.length}] must equal to the " + + // s"targetTable field size[${targetTableSchema.length}]")) + + // For MOR table, the target table field cannot be the right-value in the update action. + if (targetTableType == MOR_TABLE_TYPE_OPT_VAL) { + updateActions.foreach(update => { + val targetAttrs = update.assignments.flatMap(a => a.value.collect { + case attr: AttributeReference if mergeInto.targetTable.outputSet.contains(attr) => attr + }) + assert(targetAttrs.isEmpty, + s"Target table's field(${targetAttrs.map(_.name).mkString(",")}) cannot be the right-value of the update clause for MOR table.") + }) + } + } } object MergeIntoHoodieTableCommand { @@ -554,4 +657,16 @@ object MergeIntoHoodieTableCommand { } } + def stripCasting(expr: EqualTo): EqualTo = expr match { + case EqualTo(MatchCast(leftExpr, leftTargetType, _, _), MatchCast(rightExpr, rightTargetType, _, _)) + if leftTargetType.sameType(rightTargetType) => EqualTo(leftExpr, rightExpr) + case _ => expr + } + + def toStructType(attrs: Seq[Attribute]): StructType = + StructType(attrs.map(a => StructField(a.qualifiedName.replace('.', '_'), a.dataType, a.nullable, a.metadata))) + + def encodeAsBase64String(any: Any): String = + Base64.getEncoder.encodeToString(Serializer.toBytes(any)) } + diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/UpdateHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/UpdateHoodieTableCommand.scala index 277f2643423dd..3383e56600d20 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/UpdateHoodieTableCommand.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/UpdateHoodieTableCommand.scala @@ -18,67 +18,57 @@ package org.apache.spark.sql.hudi.command import org.apache.hudi.SparkAdapterSupport -import org.apache.hudi.common.model.HoodieRecord +import org.apache.spark.sql.HoodieCatalystExpressionUtils.attributeEquals import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.HoodieCatalogTable -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Filter, Project, UpdateTable} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._ import org.apache.spark.sql.hudi.ProvidesHoodieConfig -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.StructField -import scala.collection.JavaConverters._ - -case class UpdateHoodieTableCommand(updateTable: UpdateTable) extends HoodieLeafRunnableCommand +case class UpdateHoodieTableCommand(ut: UpdateTable) extends HoodieLeafRunnableCommand with SparkAdapterSupport with ProvidesHoodieConfig { - private 
val table = updateTable.table - private val tableId = getTableIdentifier(table) - override def run(sparkSession: SparkSession): Seq[Row] = { - logInfo(s"start execute update command for $tableId") - val sqlConf = sparkSession.sessionState.conf - val name2UpdateValue = updateTable.assignments.map { - case Assignment(attr: AttributeReference, value) => - attr.name -> value - }.toMap + val catalogTable = sparkAdapter.resolveHoodieTable(ut.table) + .map(HoodieCatalogTable(sparkSession, _)) + .get + + val tableId = catalogTable.table.qualifiedName - val updateExpressions = table.output - .map(attr => { - val UpdateValueOption = name2UpdateValue.find(f => sparkSession.sessionState.conf.resolver(f._1, attr.name)) - if(UpdateValueOption.isEmpty) attr else UpdateValueOption.get._2 - }) - .filter { // filter the meta columns - case attr: AttributeReference => - !HoodieRecord.HOODIE_META_COLUMNS.asScala.toSet.contains(attr.name) - case _=> true - } + logInfo(s"Executing 'UPDATE' command for $tableId") - val projects = updateExpressions.zip(removeMetaFields(table.schema).fields).map { - case (attr: AttributeReference, field) => - Column(cast(attr, field, sqlConf)) - case (exp, field) => - Column(Alias(cast(exp, field, sqlConf), field.name)()) + val assignedAttributes = ut.assignments.map { + case Assignment(attr: AttributeReference, value) => attr -> value } - var df = Dataset.ofRows(sparkSession, table) - if (updateTable.condition.isDefined) { - df = df.filter(Column(updateTable.condition.get)) + val filteredOutput = removeMetaFields(ut.table.output) + val targetExprs = filteredOutput.map { targetAttr => + // NOTE: [[UpdateTable]] permits partial updates and therefore here we correlate assigned + // assigned attributes to the ones of the target table. Ones not being assigned + // will simply be carried over (from the old record) + assignedAttributes.find(p => attributeEquals(p._1, targetAttr)) + .map { case (_, expr) => Alias(castIfNeeded(expr, targetAttr.dataType), targetAttr.name)() } + .getOrElse(targetAttr) } - df = df.select(projects: _*) - val config = buildHoodieConfig(HoodieCatalogTable(sparkSession, tableId)) - df.write - .format("hudi") + + val condition = ut.condition.getOrElse(TrueLiteral) + val filteredPlan = Filter(condition, Project(targetExprs, ut.table)) + + val config = buildHoodieConfig(catalogTable) + val df = Dataset.ofRows(sparkSession, filteredPlan) + + df.write.format("hudi") .mode(SaveMode.Append) .options(config) .save() - sparkSession.catalog.refreshTable(tableId.unquotedString) - logInfo(s"Finish execute update command for $tableId") + + sparkSession.catalog.refreshTable(tableId) + + logInfo(s"Finished executing 'UPDATE' command for $tableId") + Seq.empty[Row] } - def cast(exp:Expression, field: StructField, sqlConf: SQLConf): Expression = { - castIfNeeded(exp, field.dataType, sqlConf) - } } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala index 39a065f1df546..11f45c1138697 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/payload/ExpressionPayload.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.hudi.command.payload import com.github.benmanes.caffeine.cache.{Cache, Caffeine} import org.apache.avro.Schema import 
org.apache.avro.generic.{GenericData, GenericRecord, IndexedRecord} -import org.apache.hudi.AvroConversionUtils.convertAvroSchemaToStructType +import org.apache.hudi.AvroConversionUtils.{convertAvroSchemaToStructType, convertStructTypeToAvroSchema} import org.apache.hudi.DataSourceWriteOptions._ import org.apache.hudi.SparkAdapterSupport.sparkAdapter -import org.apache.hudi.avro.AvroSchemaUtils.isNullable +import org.apache.hudi.avro.AvroSchemaUtils.{isNullable, resolveNullableSchema} import org.apache.hudi.avro.HoodieAvroUtils import org.apache.hudi.avro.HoodieAvroUtils.bytesToAvro import org.apache.hudi.common.model.BaseAvroPayload @@ -38,14 +38,13 @@ import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Projection, SafeProjection} import org.apache.spark.sql.hudi.command.payload.ExpressionPayload._ -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, StructType} import org.apache.spark.{SparkConf, SparkEnv} import java.nio.ByteBuffer import java.util.function.{Function, Supplier} -import java.util.{Base64, Properties} +import java.util.{Base64, Objects, Properties} import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer /** * A HoodieRecordPayload for MergeIntoHoodieTableCommand. @@ -83,7 +82,7 @@ class ExpressionPayload(@transient record: GenericRecord, val recordSchema = getRecordSchema(properties) val sourceRecord = bytesToAvro(recordBytes, recordSchema) - val joinedRecord = joinRecord(sourceRecord, targetRecord) + val joinedRecord = joinRecord(sourceRecord, targetRecord, properties) processMatchedRecord(ConvertibleRecord(joinedRecord), Some(targetRecord), properties) } @@ -272,24 +271,23 @@ class ExpressionPayload(@transient record: GenericRecord, /** * Join the source record with the target record. 
- * - * @return */ - private def joinRecord(sourceRecord: IndexedRecord, targetRecord: IndexedRecord): GenericRecord = { + private def joinRecord(sourceRecord: IndexedRecord, targetRecord: IndexedRecord, props: Properties): GenericRecord = { val leftSchema = sourceRecord.getSchema val joinSchema = getMergedSchema(leftSchema, targetRecord.getSchema) // TODO rebase onto JoinRecord - val values = new ArrayBuffer[AnyRef](joinSchema.getFields.size()) + val values = new Array[AnyRef](joinSchema.getFields.size()) for (i <- 0 until joinSchema.getFields.size()) { val value = if (i < leftSchema.getFields.size()) { sourceRecord.get(i) } else { // skip meta field targetRecord.get(i - leftSchema.getFields.size() + HoodieRecord.HOODIE_META_COLUMNS.size()) } - values += value + values(i) = value } - convertToRecord(values.toArray, joinSchema) + + convertToRecord(values, joinSchema) } } @@ -315,6 +313,18 @@ object ExpressionPayload { */ val PAYLOAD_RECORD_AVRO_SCHEMA = "hoodie.payload.record.schema" + /** + * Property associated w/ expected combined schema of the joined records of the source (incoming batch) + * and target (existing) tables + */ + val PAYLOAD_EXPECTED_COMBINED_SCHEMA = "hoodie.payload.combined.schema" + + /** + * Internal property determining whether combined schema should be validated by [[ExpressionPayload]], + * against the one provide by [[PAYLOAD_EXPECTED_COMBINED_SCHEMA]] (default is "false") + */ + private[sql] val PAYLOAD_SHOULD_VALIDATE_COMBINED_SCHEMA = "hoodie.payload.combined.schema.validate" + /** * NOTE: PLEASE READ CAREFULLY * Spark's [[SafeProjection]] are NOT thread-safe hence cache is scoped @@ -369,16 +379,31 @@ object ExpressionPayload { ) private val schemaCache = Caffeine.newBuilder() - .maximumSize(16).build[String, Schema]() + .maximumSize(16) + .build[String, AnyRef]() + + def getExpectedCombinedSchema(props: Properties): StructType = { + ValidationUtils.checkArgument(props.containsKey(PAYLOAD_EXPECTED_COMBINED_SCHEMA), + s"Missing ${PAYLOAD_EXPECTED_COMBINED_SCHEMA} property in the provided config") + + getCachedSchema(props.getProperty(PAYLOAD_EXPECTED_COMBINED_SCHEMA), + base64EncodedStructType => + Serializer.toObject(Base64.getDecoder.decode(base64EncodedStructType)).asInstanceOf[StructType]) + } + + private def getCachedSchema[T <: AnyRef](key: String, ctor: String => T): T = { + schemaCache.get(key, new Function[String, T] { + override def apply(key: String): T = { + ctor.apply(key) + } + }).asInstanceOf[T] + } private val mergedSchemaCache = Caffeine.newBuilder() .maximumSize(16).build[(Schema, Schema), Schema]() private def parseSchema(schemaStr: String): Schema = { - schemaCache.get(schemaStr, - new Function[String, Schema] { - override def apply(t: String): Schema = new Schema.Parser().parse(t) - }) + getCachedSchema(schemaStr, new Schema.Parser().parse(_)) } private def getRecordSchema(props: Properties) = { @@ -446,13 +471,44 @@ object ExpressionPayload { }) } + private def validateCompatibleSchemas(joinedSchema: Schema, expectedStructType: StructType, props: Properties): Unit = { + ValidationUtils.checkState(expectedStructType.fields.length == joinedSchema.getFields.size, + s"Expected schema diverges from the merged one: " + + s"expected has ${expectedStructType.fields.length} fields, while merged one has ${joinedSchema.getFields.size}") + + val shouldValidate = props.getProperty(PAYLOAD_SHOULD_VALIDATE_COMBINED_SCHEMA, "false").toBoolean + if (shouldValidate) { + val expectedSchema = convertStructTypeToAvroSchema(expectedStructType, 
joinedSchema.getName, joinedSchema.getNamespace) + // NOTE: Since compared schemas are produced by essentially combining (joining) + // 2 schemas together, field names might not be appropriate and therefore + // just structural compatibility will be checked (ie based on ordering of + // the fields as well as corresponding data-types) + expectedSchema.getFields.asScala + .zip(joinedSchema.getFields.asScala) + .zipWithIndex + .foreach { + case ((expectedField, targetField), idx) => + val expectedFieldSchema = resolveNullableSchema(expectedField.schema()) + val targetFieldSchema = resolveNullableSchema(targetField.schema()) + + val equal = Objects.equals(expectedFieldSchema, targetFieldSchema) + ValidationUtils.checkState(equal, + s""" + |Expected schema diverges from the target one in #$idx field: + |Expected data-type: $expectedFieldSchema + |Received data-type: $targetFieldSchema + |""".stripMargin) + } + } + } + private def mergeSchema(a: Schema, b: Schema): Schema = { val mergedFields = a.getFields.asScala.map(field => - new Schema.Field("a_" + field.name, + new Schema.Field("source_" + field.name, field.schema, field.doc, field.defaultVal, field.order)) ++ b.getFields.asScala.map(field => - new Schema.Field("b_" + field.name, + new Schema.Field("target_" + field.name, field.schema, field.doc, field.defaultVal, field.order)) Schema.createRecord(a.getName, a.getDoc, a.getNamespace, a.isError, mergedFields.asJava) } diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieCommonSqlParser.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieCommonSqlParser.scala index 8ce8c61938761..6f78423fd10e0 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieCommonSqlParser.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/parser/HoodieCommonSqlParser.scala @@ -35,13 +35,12 @@ class HoodieCommonSqlParser(session: SparkSession, delegate: ParserInterface) extends ParserInterface with Logging with SparkAdapterSupport { private lazy val builder = new HoodieSqlCommonAstBuilder(session, delegate) - private lazy val sparkExtendedParser = sparkAdapter.createExtendedSparkParser - .map(_(session, delegate)).getOrElse(delegate) + private lazy val sparkExtendedParser = sparkAdapter.createExtendedSparkParser(session, delegate) override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => builder.visit(parser.singleStatement()) match { case plan: LogicalPlan => plan - case _=> sparkExtendedParser.parsePlan(sqlText) + case _ => sparkExtendedParser.parsePlan(sqlText) } } @@ -57,21 +56,21 @@ class HoodieCommonSqlParser(session: SparkSession, delegate: ParserInterface) override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText) - /* SPARK-37266 Added parseQuery to ParserInterface in Spark 3.3.0. This is a patch to prevent - hackers from tampering text with persistent view, it won't be called in older Spark - Don't mark this as override for backward compatibility - Can't use sparkExtendedParser directly here due to the same reason */ - def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser => - sparkAdapter.getQueryParserFromExtendedSqlParser(session, delegate, sqlText) - } + /** + * SPARK-37266 Added [[parseQuery]] to [[ParserInterface]] in Spark 3.3.0. 
+ * Don't mark this as override for backward compatibility + */ + def parseQuery(sqlText: String): LogicalPlan = sparkExtendedParser.parseQuery(sqlText) def parseRawDataType(sqlText : String) : DataType = { throw new UnsupportedOperationException(s"Unsupported parseRawDataType method") } - def parseMultipartIdentifier(sqlText: String): Seq[String] = { - sparkAdapter.parseMultipartIdentifier(delegate, sqlText) - } + /** + * Added [[parseMultipartIdentifier]] to [[ParserInterface]] in Spark 3.0.0. + * Don't mark this as override for backward compatibility + */ + def parseMultipartIdentifier(sqlText: String): Seq[String] = sparkExtendedParser.parseMultipartIdentifier(sqlText) protected def parse[T](command: String)(toResult: HoodieSqlCommonParser => T): T = { logDebug(s"Parsing command: $command") diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestBootstrap.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestBootstrap.java index 22d2e934d536c..be93171adc2d2 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestBootstrap.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestBootstrap.java @@ -58,7 +58,7 @@ import org.apache.hudi.keygen.NonpartitionedKeyGenerator; import org.apache.hudi.keygen.SimpleKeyGenerator; import org.apache.hudi.table.action.bootstrap.BootstrapUtils; -import org.apache.hudi.testutils.HoodieClientTestBase; +import org.apache.hudi.testutils.HoodieSparkClientTestBase; import org.apache.hudi.testutils.HoodieMergeOnReadTestUtils; import org.apache.avro.Schema; @@ -114,7 +114,7 @@ * Tests Bootstrap Client functionality. */ @Tag("functional") -public class TestBootstrap extends HoodieClientTestBase { +public class TestBootstrap extends HoodieSparkClientTestBase { public static final String TRIP_HIVE_COLUMN_TYPES = "bigint,string,string,string,string,double,double,double,double," + "struct,array>,boolean"; diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java index 373d187ea66bb..17c6f23089e07 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestHoodieDatasetBulkInsertHelper.java @@ -30,7 +30,7 @@ import org.apache.hudi.keygen.SimpleKeyGenerator; import org.apache.hudi.metadata.HoodieTableMetadata; import org.apache.hudi.testutils.DataSourceTestUtils; -import org.apache.hudi.testutils.HoodieClientTestBase; +import org.apache.hudi.testutils.HoodieSparkClientTestBase; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.api.java.function.ReduceFunction; import org.apache.spark.sql.Dataset; @@ -67,7 +67,7 @@ * Tests {@link HoodieDatasetBulkInsertHelper}. 
*/ @Tag("functional") -public class TestHoodieDatasetBulkInsertHelper extends HoodieClientTestBase { +public class TestHoodieDatasetBulkInsertHelper extends HoodieSparkClientTestBase { private String schemaStr; private transient Schema schema; diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestOrcBootstrap.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestOrcBootstrap.java index cc9b4cefcb6e2..77a980f01e548 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestOrcBootstrap.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/functional/TestOrcBootstrap.java @@ -55,7 +55,7 @@ import org.apache.hudi.keygen.NonpartitionedKeyGenerator; import org.apache.hudi.keygen.SimpleKeyGenerator; import org.apache.hudi.table.action.bootstrap.BootstrapUtils; -import org.apache.hudi.testutils.HoodieClientTestBase; +import org.apache.hudi.testutils.HoodieSparkClientTestBase; import org.apache.avro.Schema; import org.apache.avro.generic.GenericRecord; @@ -106,7 +106,7 @@ * Tests Bootstrap Client functionality. */ @Tag("functional") -public class TestOrcBootstrap extends HoodieClientTestBase { +public class TestOrcBootstrap extends HoodieSparkClientTestBase { @TempDir public java.nio.file.Path tmpFolder; diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/keygen/TestTimestampBasedKeyGenerator.java b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/keygen/TestTimestampBasedKeyGenerator.java index 8cfd7b04507cf..e71fce30de443 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/keygen/TestTimestampBasedKeyGenerator.java +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/keygen/TestTimestampBasedKeyGenerator.java @@ -24,6 +24,7 @@ import org.apache.avro.generic.GenericFixed; import org.apache.avro.generic.GenericRecord; import org.apache.hudi.AvroConversionUtils; +import org.apache.hudi.avro.AvroSchemaUtils; import org.apache.hudi.common.config.TypedProperties; import org.apache.hudi.common.model.HoodieKey; import org.apache.hudi.common.testutils.SchemaTestUtil; @@ -39,7 +40,6 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import scala.Function1; -import scala.Tuple2; import java.io.IOException; import java.math.BigDecimal; @@ -144,8 +144,8 @@ public void testTimestampBasedKeyGenerator() throws IOException { // timezone is GMT+8:00, createTime is BigDecimal BigDecimal decimal = new BigDecimal("1578283932000.0001"); Conversions.DecimalConversion conversion = new Conversions.DecimalConversion(); - Tuple2 resolvedNullableSchema = AvroConversionUtils.resolveAvroTypeNullability(schema.getField("createTimeDecimal").schema()); - GenericFixed avroDecimal = conversion.toFixed(decimal, resolvedNullableSchema._2, LogicalTypes.decimal(20, 4)); + Schema resolvedNullableSchema = AvroSchemaUtils.resolveNullableSchema(schema.getField("createTimeDecimal").schema()); + GenericFixed avroDecimal = conversion.toFixed(decimal, resolvedNullableSchema, LogicalTypes.decimal(20, 4)); baseRecord.put("createTimeDecimal", avroDecimal); properties = getBaseKeyConfig("createTimeDecimal", "EPOCHMILLISECONDS", "yyyy-MM-dd hh", "GMT+8:00", null); keyGen = new TimestampBasedKeyGenerator(properties); diff --git a/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/HoodieSparkClientTestBase.java 
b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/HoodieSparkClientTestBase.java new file mode 100644 index 0000000000000..ed96df17544c5 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/java/org/apache/hudi/testutils/HoodieSparkClientTestBase.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hudi.testutils; + +import org.apache.hudi.common.util.Option; +import org.apache.spark.sql.SparkSessionExtensions; +import org.apache.spark.sql.hudi.HoodieSparkSessionExtension; + +import java.util.function.Consumer; + +public abstract class HoodieSparkClientTestBase extends HoodieClientTestBase { + + @Override + protected Option> getSparkSessionExtensionsInjector() { + return Option.of((receiver) -> new HoodieSparkSessionExtension().apply(receiver)); + } + +} diff --git a/hudi-spark-datasource/hudi-spark/src/test/resources/sql-statements.sql b/hudi-spark-datasource/hudi-spark/src/test/resources/sql-statements.sql index 449ba2e2e67b0..0259f4530398e 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/resources/sql-statements.sql +++ b/hudi-spark-datasource/hudi-spark/src/test/resources/sql-statements.sql @@ -208,7 +208,7 @@ using ( select 5 as _id, 'a5' as _name, 10 as _price, 1000 as _ts, '2021-05-08' as dt ) s0 on s0._id = t0.id -when matched then update set * +when matched then update set id = _id, name = _name, price = _price, ts = _ts, dt = s0.dt when not matched then insert (id, name, price, ts, dt) values(_id, _name, _price, _ts, s0.dt); +----------+ | ok | @@ -231,10 +231,10 @@ using ( select 6 as id, '_insert' as name, 10 as price, 1000 as ts, '2021-05-08' as dt ) s0 on s0.id = t0.id -when matched and name = '_update' +when matched and s0.name = '_update' then update set id = s0.id, name = s0.name, price = s0.price, ts = s0.ts, dt = s0.dt -when matched and name = '_delete' then delete -when not matched and name = '_insert' then insert *; +when matched and s0.name = '_delete' then delete +when not matched and s0.name = '_insert' then insert *; +----------+ | ok | +----------+ diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala index f995e484d4912..f60b95d8f5aa1 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala @@ -18,7 +18,7 @@ package org.apache.hudi import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import 
org.apache.spark.sql.HoodieCatalystExpressionUtils.resolveExpr import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.encoders.DummyExpressionHolder @@ -62,7 +62,7 @@ case class IndexRow(fileName: String, def toRow: Row = Row(productIterator.toSeq: _*) } -class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSupport { +class TestDataSkippingUtils extends HoodieSparkClientTestBase with SparkAdapterSupport { var spark: SparkSession = _ diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestGenericRecordAndRowConsistency.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestGenericRecordAndRowConsistency.scala index a3e4a8c8302c5..9a557a343ef06 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestGenericRecordAndRowConsistency.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestGenericRecordAndRowConsistency.scala @@ -18,14 +18,14 @@ package org.apache.hudi import org.apache.hudi.config.HoodieWriteConfig -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.spark.sql.{DataFrame, SparkSession} import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals} import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} import java.sql.{Date, Timestamp} -class TestGenericRecordAndRowConsistency extends HoodieClientTestBase { +class TestGenericRecordAndRowConsistency extends HoodieSparkClientTestBase { var spark: SparkSession = _ val commonOpts = Map( diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala index 2782159b2a0d6..70e7e3a1544eb 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieFileIndex.scala @@ -41,6 +41,7 @@ import org.apache.hudi.exception.HoodieException import org.apache.hudi.keygen.ComplexKeyGenerator import org.apache.hudi.keygen.TimestampBasedAvroKeyGenerator.TimestampType import org.apache.hudi.keygen.constant.KeyGeneratorOptions.Config +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.metadata.HoodieTableMetadata import org.apache.hudi.testutils.HoodieClientTestBase import org.apache.hudi.util.JFunction @@ -61,7 +62,7 @@ import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ import scala.util.Random -class TestHoodieFileIndex extends HoodieClientTestBase with ScalaAssertionSupport { +class TestHoodieFileIndex extends HoodieSparkClientTestBase with ScalaAssertionSupport { var spark: SparkSession = _ val commonOpts = Map( diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestBasicSchemaEvolution.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestBasicSchemaEvolution.scala index ccbd04a45b68a..b5d1e61b7aa30 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestBasicSchemaEvolution.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestBasicSchemaEvolution.scala @@ -25,7 +25,7 @@ import org.apache.hudi.common.util import org.apache.hudi.config.HoodieWriteConfig import org.apache.hudi.exception.SchemaCompatibilityException import 
org.apache.hudi.functional.TestBasicSchemaEvolution.{dropColumn, injectColumnAt} -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.util.JFunction import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, ScalaAssertionSupport} import org.apache.spark.sql.hudi.HoodieSparkSessionExtension @@ -40,7 +40,7 @@ import java.util.function.Consumer import scala.collection.JavaConversions.asScalaBuffer import scala.collection.JavaConverters._ -class TestBasicSchemaEvolution extends HoodieClientTestBase with ScalaAssertionSupport { +class TestBasicSchemaEvolution extends HoodieSparkClientTestBase with ScalaAssertionSupport { var spark: SparkSession = null val commonOpts = Map( diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala index ea34123be437f..76e8c26b7bd43 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestCOWDataSource.scala @@ -36,7 +36,7 @@ import org.apache.hudi.functional.TestCOWDataSource.convertColumnsToNullable import org.apache.hudi.keygen._ import org.apache.hudi.keygen.constant.KeyGeneratorOptions.Config import org.apache.hudi.metrics.Metrics -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.util.JFunction import org.apache.hudi.{AvroConversionUtils, DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSparkRecordMerger, QuickstartUtils, ScalaAssertionSupport} import org.apache.spark.sql._ @@ -46,7 +46,7 @@ import org.apache.spark.sql.types._ import org.joda.time.DateTime import org.joda.time.format.DateTimeFormat import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail} -import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, Test} import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.{CsvSource, EnumSource, ValueSource} @@ -59,7 +59,7 @@ import scala.collection.JavaConverters._ /** * Basic tests on the spark datasource for COW table. 
*/ -class TestCOWDataSource extends HoodieClientTestBase with ScalaAssertionSupport { +class TestCOWDataSource extends HoodieSparkClientTestBase with ScalaAssertionSupport { var spark: SparkSession = null val commonOpts = Map( "hoodie.insert.shuffle.parallelism" -> "4", @@ -962,7 +962,6 @@ class TestCOWDataSource extends HoodieClientTestBase with ScalaAssertionSupport @EnumSource(value = classOf[HoodieRecordType], names = Array("AVRO", "SPARK")) def testWriteSmallPrecisionDecimalTable(recordType: HoodieRecordType): Unit = { val (writeOpts, readOpts) = getWriterReaderOpts(recordType) - val records1 = recordsToStrings(dataGen.generateInserts("001", 5)).toList val inputDF1 = spark.read.json(spark.sparkContext.parallelize(records1, 2)) .withColumn("shortDecimal", lit(new java.math.BigDecimal(s"2090.0000"))) // create decimalType(8, 4) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala index 289d641e2e812..5eb785290807a 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestColumnStatsIndex.scala @@ -29,7 +29,7 @@ import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient} import org.apache.hudi.common.util.ParquetUtils import org.apache.hudi.config.HoodieWriteConfig import org.apache.hudi.functional.TestColumnStatsIndex.ColumnStatsTestCase -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{ColumnStatsIndexSupport, DataSourceWriteOptions} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -48,7 +48,7 @@ import scala.collection.JavaConverters._ import scala.util.Random @Tag("functional") -class TestColumnStatsIndex extends HoodieClientTestBase { +class TestColumnStatsIndex extends HoodieSparkClientTestBase { var spark: SparkSession = _ val sourceTableSchema = diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestEmptyCommit.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestEmptyCommit.scala index addd23ef82e6a..eea719203f7ca 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestEmptyCommit.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestEmptyCommit.scala @@ -18,7 +18,7 @@ package org.apache.hudi.functional import org.apache.hudi.config.HoodieWriteConfig -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{DataSourceWriteOptions, HoodieDataSourceHelpers} import org.apache.spark.sql.{SaveMode, SparkSession} import org.junit.jupiter.api.Assertions.assertEquals @@ -26,7 +26,7 @@ import org.junit.jupiter.api.{AfterEach, BeforeEach} import org.junit.jupiter.params.ParameterizedTest import org.junit.jupiter.params.provider.ValueSource -class TestEmptyCommit extends HoodieClientTestBase { +class TestEmptyCommit extends HoodieSparkClientTestBase { var spark: SparkSession = _ val commonOpts = Map( "hoodie.insert.shuffle.parallelism" -> "4", diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieActiveTimeline.scala 
b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieActiveTimeline.scala index 03bd2fe0776b5..7b2585dea6a4d 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieActiveTimeline.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestHoodieActiveTimeline.scala @@ -21,7 +21,7 @@ import org.apache.hudi.common.model.HoodieFileFormat import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings import org.apache.hudi.config.HoodieWriteConfig -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{DataSourceWriteOptions, HoodieDataSourceHelpers} import org.apache.log4j.LogManager @@ -36,7 +36,7 @@ import scala.collection.JavaConversions._ /** * Tests on HoodieActionTimeLine using the real hudi table. */ -class TestHoodieActiveTimeline extends HoodieClientTestBase { +class TestHoodieActiveTimeline extends HoodieSparkClientTestBase { var spark: SparkSession = null diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestIncrementalReadWithFullTableScan.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestIncrementalReadWithFullTableScan.scala index b828a0626bb8b..ffb8b020f826a 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestIncrementalReadWithFullTableScan.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestIncrementalReadWithFullTableScan.scala @@ -25,7 +25,7 @@ import org.apache.hudi.common.table.timeline.{HoodieInstant, HoodieInstantTimeGe import org.apache.hudi.common.table.timeline.HoodieTimeline.GREATER_THAN import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings import org.apache.hudi.config.HoodieWriteConfig -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.log4j.LogManager import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} @@ -37,7 +37,7 @@ import org.junit.jupiter.params.provider.EnumSource import scala.collection.JavaConversions.asScalaBuffer -class TestIncrementalReadWithFullTableScan extends HoodieClientTestBase { +class TestIncrementalReadWithFullTableScan extends HoodieSparkClientTestBase { var spark: SparkSession = null private val log = LogManager.getLogger(classOf[TestIncrementalReadWithFullTableScan]) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestLayoutOptimization.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestLayoutOptimization.scala index b87c813a0fa12..6400468da8173 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestLayoutOptimization.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestLayoutOptimization.scala @@ -24,7 +24,7 @@ import org.apache.hudi.common.table.HoodieTableMetaClient import org.apache.hudi.common.table.timeline.{HoodieInstant, HoodieTimeline} import org.apache.hudi.common.testutils.RawTripTestPayload.recordsToStrings import org.apache.hudi.config.{HoodieClusteringConfig, HoodieWriteConfig} -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import 
org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions} import org.apache.spark.sql._ import org.apache.spark.sql.types._ @@ -37,7 +37,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource} import scala.collection.JavaConversions._ @Tag("functional") -class TestLayoutOptimization extends HoodieClientTestBase { +class TestLayoutOptimization extends HoodieSparkClientTestBase { var spark: SparkSession = _ val sourceTableSchema = diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala index de3ace23fb8b4..354d85e168008 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSource.scala @@ -35,7 +35,7 @@ import org.apache.hudi.index.HoodieIndex.IndexType import org.apache.hudi.keygen.NonpartitionedKeyGenerator import org.apache.hudi.keygen.constant.KeyGeneratorOptions.Config import org.apache.hudi.table.action.compact.CompactionTriggerStrategy -import org.apache.hudi.testutils.{DataSourceTestUtils, HoodieClientTestBase} +import org.apache.hudi.testutils.{DataSourceTestUtils, HoodieSparkClientTestBase} import org.apache.hudi.util.JFunction import org.apache.hudi.{DataSourceReadOptions, DataSourceUtils, DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSparkRecordMerger, SparkDatasetMixin} import org.apache.log4j.LogManager @@ -56,7 +56,7 @@ import scala.collection.JavaConverters._ /** * Tests on Spark DataSource for MOR table. */ -class TestMORDataSource extends HoodieClientTestBase with SparkDatasetMixin { +class TestMORDataSource extends HoodieSparkClientTestBase with SparkDatasetMixin { var spark: SparkSession = null private val log = LogManager.getLogger(classOf[TestMORDataSource]) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSourceWithBucketIndex.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSourceWithBucketIndex.scala index 187de2d8e0671..41913074d54ed 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSourceWithBucketIndex.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestMORDataSourceWithBucketIndex.scala @@ -24,7 +24,7 @@ import org.apache.hudi.index.HoodieIndex.IndexType import org.apache.hudi.keygen.constant.KeyGeneratorOptions import org.apache.hudi.table.action.commit.SparkBucketIndexPartitioner import org.apache.hudi.table.storage.HoodieStorageLayout -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers} import org.apache.spark.sql._ import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} @@ -35,7 +35,7 @@ import scala.collection.JavaConversions._ /** * */ -class TestMORDataSourceWithBucketIndex extends HoodieClientTestBase { +class TestMORDataSourceWithBucketIndex extends HoodieSparkClientTestBase { var spark: SparkSession = null val commonOpts = Map( diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStructuredStreaming.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStructuredStreaming.scala index 
1e3356a958350..9487e54ea6082 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStructuredStreaming.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestStructuredStreaming.scala @@ -30,7 +30,7 @@ import org.apache.hudi.common.testutils.{HoodieTestDataGenerator, HoodieTestTabl import org.apache.hudi.common.util.{CollectionUtils, CommitUtils} import org.apache.hudi.config.{HoodieClusteringConfig, HoodieCompactionConfig, HoodieLockConfig, HoodieWriteConfig} import org.apache.hudi.exception.TableNotFoundException -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions, HoodieDataSourceHelpers, HoodieSinkCheckpoint} import org.apache.log4j.LogManager import org.apache.spark.sql._ @@ -49,7 +49,7 @@ import scala.concurrent.{Await, Future} /** * Basic tests on the spark datasource for structured streaming sink */ -class TestStructuredStreaming extends HoodieClientTestBase { +class TestStructuredStreaming extends HoodieSparkClientTestBase { private val log = LogManager.getLogger(getClass) var spark: SparkSession = _ diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestTimeTravelQuery.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestTimeTravelQuery.scala index 5a71f0e371360..abdf57f99f57f 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestTimeTravelQuery.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/TestTimeTravelQuery.scala @@ -23,7 +23,7 @@ import org.apache.hudi.common.table.{HoodieTableMetaClient, TableSchemaResolver} import org.apache.hudi.common.table.timeline.HoodieActiveTimeline import org.apache.hudi.config.HoodieWriteConfig import org.apache.hudi.keygen.{ComplexKeyGenerator, NonpartitionedKeyGenerator} -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.hudi.{DataSourceReadOptions, DataSourceWriteOptions} import org.apache.spark.sql.{Row, SaveMode, SparkSession} @@ -35,7 +35,7 @@ import org.junit.jupiter.params.provider.EnumSource import java.text.SimpleDateFormat -class TestTimeTravelQuery extends HoodieClientTestBase { +class TestTimeTravelQuery extends HoodieSparkClientTestBase { var spark: SparkSession =_ val commonOpts = Map( "hoodie.insert.shuffle.parallelism" -> "4", diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/cdc/HoodieCDCTestBase.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/cdc/HoodieCDCTestBase.scala index fce3f2289e691..88174b5d5f1ef 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/cdc/HoodieCDCTestBase.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/functional/cdc/HoodieCDCTestBase.scala @@ -28,7 +28,7 @@ import org.apache.hudi.common.table.log.block.HoodieDataBlock import org.apache.hudi.common.table.timeline.HoodieInstant import org.apache.hudi.common.testutils.RawTripTestPayload import org.apache.hudi.config.{HoodieCleanConfig, HoodieWriteConfig} -import org.apache.hudi.testutils.HoodieClientTestBase +import org.apache.hudi.testutils.HoodieSparkClientTestBase import org.apache.avro.Schema import org.apache.avro.generic.{GenericRecord, IndexedRecord} import org.apache.hadoop.fs.Path @@ 
-41,7 +41,7 @@ import org.junit.jupiter.api.Assertions.{assertEquals, assertNotEquals, assertNu import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ -abstract class HoodieCDCTestBase extends HoodieClientTestBase { +abstract class HoodieCDCTestBase extends HoodieSparkClientTestBase { var spark: SparkSession = _ diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/HoodieSparkSqlTestBase.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/HoodieSparkSqlTestBase.scala index 7077b2d37a33e..94fc3bb9d5f55 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/HoodieSparkSqlTestBase.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/HoodieSparkSqlTestBase.scala @@ -50,6 +50,9 @@ class HoodieSparkSqlTestBase extends FunSuite with BeforeAndAfterAll { dir } + // NOTE: We need to set "spark.testing" property to make sure Spark can appropriately + // recognize environment as testing + System.setProperty("spark.testing", "true") // NOTE: We have to fix the timezone to make sure all date-/timestamp-bound utilities output // is consistent with the fixtures DateTimeZone.setDefault(DateTimeZone.UTC) diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala index 99fd72d5a280f..6a203718fd1e2 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable.scala @@ -20,11 +20,13 @@ package org.apache.spark.sql.hudi import org.apache.hudi.{DataSourceReadOptions, HoodieDataSourceHelpers, HoodieSparkUtils, ScalaAssertionSupport} import org.apache.hudi.common.fs.FSUtils import org.apache.hudi.exception.SchemaCompatibilityException +import org.apache.spark.sql.internal.SQLConf class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSupport { test("Test MergeInto Basic") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -115,6 +117,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto with ignored record") { withRecordType()(withTempDir {tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val sourceTable = generateTableName val targetTable = generateTableName // Create source table @@ -193,6 +196,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto for MOR table ") { withRecordType()(withTempDir {tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create a mor partitioned table. 
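Editor's note: the test hunks above keep repeating the same two pieces of setup, the suite-level "spark.testing" system property added to HoodieSparkSqlTestBase and a per-test `set hoodie.payload.combined.schema.validate = true` session flag issued before each MERGE INTO scenario. The following is a minimal, self-contained sketch of that pattern; it is illustrative only and not part of the patch, and the table name, schema, and values are made up.

// Illustrative sketch (not from the patch): a MERGE INTO test following the setup pattern above.
// The "spark.testing" system property is inherited from HoodieSparkSqlTestBase; the session flag
// enabling combined-schema validation is set explicitly, as the hunks above do in each test.
class ExampleMergeIntoSuite extends HoodieSparkSqlTestBase {
  test("merge with combined-schema validation enabled") {
    withTempDir { tmp =>
      spark.sql("set hoodie.payload.combined.schema.validate = true")
      val tableName = generateTableName
      spark.sql(
        s"""create table $tableName (id int, name string, price double, ts long) using hudi
           | location '${tmp.getCanonicalPath}/$tableName'
           | tblproperties (primaryKey = 'id', preCombineField = 'ts')
           |""".stripMargin)
      spark.sql(s"insert into $tableName values (1, 'a1', 10.0, 1000)")
      spark.sql(
        s"""merge into $tableName t0
           | using (select 1 as id, 'a1' as name, 12.0 as price, 1001 as ts) s0
           | on t0.id = s0.id
           | when matched then update set *
           |""".stripMargin)
      checkAnswer(s"select id, name, price, ts from $tableName")(Seq(1, "a1", 12.0, 1001))
    }
  }
}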
spark.sql( @@ -277,7 +281,8 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo | select 2 as s_id, 'a2' as s_name, 15 as s_price, 1001 as s_ts, '2021-03-21' as dt | ) s0 | on t0.id = s0.s_id - | when matched and s_ts = 1001 then update set * + | when matched and s_ts = 1001 + | then update set id = s_id, name = s_name, price = s_price, ts = s_ts, t0.dt = s0.dt """.stripMargin ) checkAnswer(s"select id,name,price,dt from $tableName order by id")( @@ -286,7 +291,15 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo ) // Delete with condition expression. - spark.sql( + val errorMessage = if (HoodieSparkUtils.gteqSpark3_2) { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. Found `t0.id = (s0.s_id + 1)`" + } else if (HoodieSparkUtils.gteqSpark3_1) { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. Found `t0.`id` = (s0.`s_id` + 1)`" + } else { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. Found `t0.`id` = (s0.`s_id` + 1)`;" + } + + checkException( s""" | merge into $tableName t0 | using ( @@ -295,6 +308,17 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo | on t0.id = s0.s_id + 1 | when matched and s_ts = 1001 then delete """.stripMargin + )(errorMessage) + + spark.sql( + s""" + | merge into $tableName t0 + | using ( + | select 2 as s_id, 'a2' as s_name, 15 as s_price, 1001 as ts, '2021-03-21' as dt + | ) s0 + | on t0.id = s0.s_id + | when matched and s0.ts = 1001 then delete + """.stripMargin ) checkAnswer(s"select id,name,price,dt from $tableName order by id")( Seq(1, "a1", 12, "2021-03-21") @@ -304,6 +328,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto with insert only") { withRecordType()(withTempDir {tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") // Create a partitioned mor table val tableName = generateTableName spark.sql( @@ -358,6 +383,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto For PreCombineField") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") Seq("cow", "mor").foreach { tableType => val tableName1 = generateTableName // Create a mor partitioned table. 
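Editor's note: the `checkException` assertions introduced above only differ in how each Spark version renders attribute names (Spark 3.2+ drops the backticks, Spark 3.1 keeps them) and in the trailing ';' appended by versions below 3.1. A hypothetical helper, not present in the patch, that captures the branching these tests build inline:

// Hypothetical consolidation of the version-gated expected messages constructed above.
// Spark >= 3.2 prints attributes without backticks; Spark 3.1 keeps them; versions below
// 3.1 additionally terminate the AnalysisException message with ';'.
private def versionedErrorMessage(spark32Msg: String, spark31Msg: String): String =
  if (HoodieSparkUtils.gteqSpark3_2) spark32Msg
  else if (HoodieSparkUtils.gteqSpark3_1) spark31Msg
  else spark31Msg + ";"

A caller would then pass the two rendered variants once and use the result as the expected message in `checkException(...)(...)`, instead of re-deriving the branches in every test.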
@@ -431,6 +457,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto with preCombine field expression") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") Seq("cow", "mor").foreach { tableType => val tableName1 = generateTableName spark.sql( @@ -453,8 +480,10 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo // Insert data spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""") + // // Update data with a value expression on preCombine field // 1) set source column name to be same as target column + // spark.sql( s""" | merge into $tableName1 as t0 @@ -470,8 +499,16 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo Seq(1, "a1", 22, "2021-03-21", 1001) ) + // // 2) set source column name to be different with target column - spark.sql( + // + val errorMessage = if (HoodieSparkUtils.gteqSpark3_1) { + "Failed to resolve pre-combine field `v` w/in the source-table output" + } else { + "Failed to resolve pre-combine field `v` w/in the source-table output;" + } + + checkException( s""" | merge into $tableName1 as t0 | using ( @@ -480,6 +517,17 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo | on t0.id = s0.s_id | when matched then update set id=s0.s_id, name=s0.s_name, price=s0.s_price*2, v=s0.s_v+2, dt=s0.dt """.stripMargin + )(errorMessage) + + spark.sql( + s""" + | merge into $tableName1 as t0 + | using ( + | select 1 as s_id, 'a1' as s_name, 12 as s_price, 1000 as v, '2021-03-21' as dt + | ) as s0 + | on t0.id = s0.s_id + | when matched then update set id=s0.s_id, name=s0.s_name, price=s0.s_price*2, v=s0.v+2, dt=s0.dt + """.stripMargin ) // Update success as new value 1002 is bigger than original value 1001 checkAnswer(s"select id,name,price,dt,v from $tableName1")( @@ -491,6 +539,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto with primaryKey expression") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName1 = generateTableName spark.sql( s""" @@ -514,15 +563,34 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo spark.sql(s"""insert into $tableName1 values(2, 'a2', 20, 2000, '2021-03-21')""") spark.sql(s"""insert into $tableName1 values(1, 'a1', 10, 1000, '2021-03-21')""") + // // Delete data with a condition expression on primaryKey field // 1) set source column name to be same as target column + // + val complexConditionsErrorMessage = if (HoodieSparkUtils.gteqSpark3_2) { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. Found `t0.id = (s0.id + 1)`" + } else if (HoodieSparkUtils.gteqSpark3_1) { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. Found `t0.`id` = (s0.`id` + 1)`" + } else { + "Only simple conditions of the form `t.id = s.id` are allowed on the primary-key column. 
Found `t0.`id` = (s0.`id` + 1)`;" + } + + checkException( + s"""merge into $tableName1 t0 + | using ( + | select 1 as id, 'a1' as name, 15 as price, 1001 as v, '2021-03-21' as dt + | ) s0 + | on t0.id = s0.id + 1 + | when matched then delete + """.stripMargin)(complexConditionsErrorMessage) + spark.sql( s""" | merge into $tableName1 t0 | using ( - | select 1 as id, 'a1' as name, 15 as price, 1001 as v, '2021-03-21' as dt + | select 2 as id, 'a2' as name, 20 as price, 2000 as v, '2021-03-21' as dt | ) s0 - | on t0.id = s0.id + 1 + | on t0.id = s0.id | when matched then delete """.stripMargin ) @@ -531,17 +599,38 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo Seq(3, "a3", 30, 3000, "2021-03-21") ) - // 2) set source column name to be different with target column + // + // 2.a) set source column name to be different with target column (should fail unable to match pre-combine field) + // + val failedToResolveErrorMessage = if (HoodieSparkUtils.gteqSpark3_1) { + "Failed to resolve pre-combine field `v` w/in the source-table output" + } else { + "Failed to resolve pre-combine field `v` w/in the source-table output;" + } + + checkException( + s"""merge into $tableName1 t0 + | using ( + | select 3 as s_id, 'a3' as s_name, 30 as s_price, 3000 as s_v, '2021-03-21' as dt + | ) s0 + | on t0.id = s0.s_id + | when matched then delete + |""".stripMargin)(failedToResolveErrorMessage) + + // + // 2.b) set source column name to be different with target column + // spark.sql( s""" | merge into $tableName1 t0 | using ( - | select 2 as s_id, 'a1' as s_name, 15 as s_price, 1001 as s_v, '2021-03-21' as dt + | select 3 as s_id, 'a3' as s_name, 30 as s_price, 3000 as v, '2021-03-21' as dt | ) s0 - | on t0.id = s0.s_id + 1 + | on t0.id = s0.s_id | when matched then delete """.stripMargin ) + checkAnswer(s"select id,name,price,v,dt from $tableName1 order by id")( Seq(1, "a1", 10, 1000, "2021-03-21") ) @@ -550,6 +639,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto with combination of delete update insert") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val sourceTable = generateTableName val targetTable = generateTableName // Create source table @@ -596,9 +686,9 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo | merge into $targetTable as t0 | using $sourceTable as s0 | on t0.id = s0.id - | when matched and id = 10 then delete - | when matched and id < 10 then update set name='sxx', price=s0.price*2, ts=s0.ts+10000, dt=s0.dt - | when not matched and id > 10 then insert * + | when matched and s0.id = 10 then delete + | when matched and s0.id < 10 then update set id=s0.id, name='sxx', price=s0.price*2, ts=s0.ts+10000, dt=s0.dt + | when not matched and s0.id > 10 then insert * """.stripMargin) checkAnswer(s"select id,name,price,ts,dt from $targetTable order by id")( Seq(7, "a7", 70, 1007, "2021-03-21"), @@ -612,6 +702,8 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Merge Hudi to Hudi") { withRecordType()(withTempDir { tmp => + spark.sessionState.conf.setConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED, false) + spark.sql("set hoodie.payload.combined.schema.validate = true") Seq("cow", "mor").foreach { tableType => val sourceTable = generateTableName spark.sql( @@ -717,6 +809,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test Different Type of 
PreCombineField") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val typeAndValue = Seq( ("string", "'1000'"), ("int", 1000), @@ -747,11 +840,10 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo s""" | merge into $tableName | using ( - | select 1 as id, 'a1' as name, 10 as price, $dataValue as c0, '1' as flag + | select 1 as id, 'a1' as name, 10 as price, $dataValue as c, '1' as flag | ) s0 | on s0.id = $tableName.id - | when matched and flag = '1' then update set - | id = s0.id, name = s0.name, price = s0.price, c = s0.c0 + | when matched and flag = '1' then update set * | when not matched and flag = '1' then insert * """.stripMargin) checkAnswer(s"select id, name, price from $tableName")( @@ -777,6 +869,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto For MOR With Compaction On") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName spark.sql( s""" @@ -827,6 +920,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MereInto With Null Fields") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val types = Seq( "string" , "int", @@ -872,6 +966,7 @@ class TestMergeIntoTable extends HoodieSparkSqlTestBase with ScalaAssertionSuppo test("Test MergeInto With All Kinds Of DataType") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val dataAndTypes = Seq( ("string", "'a1'"), ("int", "10"), diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala index 0a8458063cd5a..48ba4ea0ca020 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala @@ -25,6 +25,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test MergeInto for MOR table 2") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create a mor partitioned table. 
spark.sql( @@ -142,6 +143,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test Merge Into CTAS Table") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName spark.sql( s""" @@ -163,9 +165,9 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" |merge into $tableName h0 |using ( - | select 1 as s_id, 'a1_1' as name + | select 1 as id, 'a1_1' as name |) s0 - |on h0.id = s0.s_id + |on h0.id = s0.id |when matched then update set * |""".stripMargin ) @@ -177,6 +179,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test Merge With Complex Data Type") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName spark.sql( s""" @@ -240,6 +243,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test column name matching for insert * and update set *") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -264,23 +268,27 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { Seq(1, "a1", 1.0, 10, "2021-03-21") ) - // Test the order of column types in sourceTable is similar to that in targetTable + // NOTE: When using star update/insert clauses (ie `insert *` or `update *`) order of the + // columns in the source and target table _have to_ match (Spark won't be applying any column + // column resolution logic) spark.sql( s""" |merge into $tableName as t0 |using ( - | select 1 as id, '2021-05-05' as dt, 1002 as ts, 97 as price, 'a1' as name union all - | select 1 as id, '2021-05-05' as dt, 1003 as ts, 98 as price, 'a2' as name union all - | select 2 as id, '2021-05-05' as dt, 1001 as ts, 99 as price, 'a3' as name + | select 1 as id, 'a1' as name, 97 as price, 1002 as ts, '2021-05-05' as dt union all + | select 1 as id, 'a2' as name, 98 as price, 1003 as ts, '2021-05-05' as dt union all + | select 2 as id, 'a3' as name, 99 as price, 1001 as ts, '2021-05-05' as dt | ) as s0 |on t0.id = s0.id |when matched then update set * |when not matched then insert * |""".stripMargin) + checkAnswer(s"select id, name, price, ts, dt from $tableName")( Seq(1, "a2", 98.0, 1003, "2021-05-05"), Seq(2, "a3", 99.0, 1001, "2021-05-05") ) + // Test the order of the column types of sourceTable is different from the column types of targetTable spark.sql( s""" @@ -291,8 +299,8 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { | select 3 as id, 'a3' as name, 1000 as ts, '2021-05-05' as dt, 102 as price | ) as s0 |on t0.id = s0.id - |when matched then update set * - |when not matched then insert * + |when matched then update set t0.name = s0.name, t0.ts = s0.ts, t0.dt = s0.dt, t0.price = s0.price + |when not matched then insert (id, name, ts, dt, price) values (s0.id, s0.name, s0.ts, s0.dt, s0.price) |""".stripMargin) checkAnswer(s"select id, name, price, ts, dt from $tableName")( Seq(1, "a1", 100.0, 1004, "2021-05-05"), @@ -305,8 +313,8 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" |merge into $tableName as t0 |using ( - | select 1 as id, 'a6' as name, 1006 as ts, '2021-05-05' as dt, 106 as price, '0' as flag union all - | select 4 as id, 'a4' as name, 1000 as ts, '2021-05-06' as dt, 100 as price, '1' as flag + | select 1 as id, 'a6' as name, 106 as price, 1006 as ts, '2021-05-05' as dt, '0' as flag union all + | select 4 as id, 'a4' as name, 
100 as price, 1000 as ts, '2021-05-06' as dt, '1' as flag | ) as s0 |on t0.id = s0.id |when matched and flag = '1' then update set * @@ -323,6 +331,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test MergeInto For Source Table With Column Aliases") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -366,6 +375,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test MergeInto When PrimaryKey And PreCombineField Of Source Table And Target Table Differ In Case Only") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -390,8 +400,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { | select 1 as ID, 'a1' as NAME, 10 as PRICE, 1000 as TS, '1' as FLAG | ) s0 | on s0.ID = $tableName.id - | when matched and FLAG = '1' then update set - | id = s0.ID, name = s0.NAME, price = s0.PRICE, ts = s0.TS + | when matched and FLAG = '1' then update set * | when not matched and FLAG = '1' then insert * |""".stripMargin) checkAnswer(s"select id, name, price, ts from $tableName")( @@ -406,8 +415,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { | select 1 as ID, 'a1' as NAME, 11 as PRICE, 1001 as TS, '1' as FLAG | ) s0 | on s0.id = $tableName.id - | when matched and FLAG = '1' then update set - | id = s0.id, name = s0.NAME, price = s0.PRICE, ts = s0.ts + | when matched and FLAG = '1' then update set id = s0.id, name = s0.NAME, price = s0.PRICE, ts = s0.ts | when not matched and FLAG = '1' then insert * |""".stripMargin) checkAnswer(s"select id, name, price, ts from $tableName")( @@ -422,8 +430,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { | select 2 as ID, 'a2' as NAME, 12 as PRICE, 1002 as TS, '1' as FLAG | ) s0 | on cast(s0.id as int) = $tableName.id - | when matched and FLAG = '1' then update set - | id = s0.id, name = s0.NAME, price = s0.PRICE, ts = s0.ts + | when matched and FLAG = '1' then update set id = s0.id, name = s0.NAME, price = s0.PRICE, ts = s0.ts | when not matched and FLAG = '1' then insert * |""".stripMargin) checkAnswer(s"select id, name, price, ts from $tableName")( @@ -435,6 +442,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test ignoring case") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -474,7 +482,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" | merge into $tableName | using ( - | select 1 as id, 'a1' as name, 20 as PRICE, '2021-05-05' as dt, 1001 as ts + | select 1 as id, 'a1' as name, 20 as PRICE, 1001 as ts, '2021-05-05' as dt | ) s0 | on s0.id = $tableName.id | when matched then update set @@ -490,8 +498,8 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" | merge into $tableName as t0 | using ( - | select 1 as id, 'a1' as name, 1111 as ts, '2021-05-05' as dt, 111 as PRICE union all - | select 2 as id, 'a2' as name, 1112 as ts, '2021-05-05' as dt, 112 as PRICE + | select 1 as id, 'a1' as name, 111 as PRICE, 1111 as ts, '2021-05-05' as dt union all + | select 2 as id, 'a2' as name, 112 as PRICE, 1112 as ts, '2021-05-05' as dt | ) as s0 | on t0.id = s0.id | when matched then update set * @@ -506,6 +514,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test ignoring 
case for MOR table") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create a mor partitioned table. spark.sql( @@ -531,7 +540,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" | merge into $tableName as t0 | using ( - | select 1 as id, 'a1' as NAME, 1111 as ts, '2021-05-05' as DT, 111 as price + | select 1 as id, 'a1' as NAME, 111 as price, 1111 as ts, '2021-05-05' as DT | ) as s0 | on t0.id = s0.id | when matched then update set * @@ -546,6 +555,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test only insert when source table contains history") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table spark.sql( @@ -591,6 +601,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { test("Test only insert when source table contains history and target table has multiple keys") { withRecordType()(withTempDir { tmp => + spark.sql("set hoodie.payload.combined.schema.validate = true") val tableName = generateTableName // Create table with multiple keys spark.sql( @@ -661,7 +672,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" | merge into $tableName as t0 | using ( - | select 'a1' as name, 1 as id, 10 as price, 1000 as ts, '2021-03-21' as dt + | select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '2021-03-21' as dt | ) as s0 | on t0.id = s0.id | when not matched and s0.id % 2 = 1 then insert * @@ -738,7 +749,7 @@ class TestMergeIntoTable2 extends HoodieSparkSqlTestBase { s""" | merge into $tableName as t0 | using ( - | select 'a2' as name, 1 as id, 1000 as ts + | select 1 as id, 'a2' as name, 1000 as ts | ) as s0 | on t0.id = s0.id | when matched then update set t0.name = s0.name, t0.ts = s0.ts diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestPartialUpdateForMergeInto.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestPartialUpdateForMergeInto.scala index 1af7a162be185..2284d76ab3a9a 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestPartialUpdateForMergeInto.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestPartialUpdateForMergeInto.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hudi +import org.apache.hudi.HoodieSparkUtils + class TestPartialUpdateForMergeInto extends HoodieSparkSqlTestBase { test("Test Partial Update") { @@ -98,15 +100,19 @@ class TestPartialUpdateForMergeInto extends HoodieSparkSqlTestBase { | preCombineField = '_ts' |)""".stripMargin) + val failedToResolveErrorMessage = if (HoodieSparkUtils.gteqSpark3_1) { + "Failed to resolve pre-combine field `_ts` w/in the source-table output" + } else { + "Failed to resolve pre-combine field `_ts` w/in the source-table output;" + } + checkExceptionContain( s""" |merge into $tableName t0 |using ( select 1 as id, 'a1' as name, 12 as price) s0 |on t0.id = s0.id |when matched then update set price = s0.price - """.stripMargin)( - "Missing specify value for the preCombineField: _ts in merge-into update action. " + - "You should add '... update set _ts = xx....' 
to the when-matched clause.") + """.stripMargin)(failedToResolveErrorMessage) val tableName2 = generateTableName spark.sql( @@ -122,16 +128,5 @@ class TestPartialUpdateForMergeInto extends HoodieSparkSqlTestBase { | primaryKey = 'id', | preCombineField = '_ts' |)""".stripMargin) - - checkExceptionContain( - s""" - |merge into $tableName2 t0 - |using ( select 1 as id, 'a1' as name, 12 as price, 1000 as ts) s0 - |on t0.id = s0.id - |when matched then update set price = s0.price, _ts = s0.ts - """.stripMargin)( - "Missing specify the value for target field: 'id' in merge into update action for MOR table. " + - "Currently we cannot support partial update for MOR, please complete all the target fields " + - "just like '...update set id = s0.id, name = s0.name ....'") } } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/command/index/TestIndexSyntax.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/command/index/TestIndexSyntax.scala index 537d3ad6a305a..448c32294f601 100644 --- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/command/index/TestIndexSyntax.scala +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/command/index/TestIndexSyntax.scala @@ -56,11 +56,11 @@ class TestIndexSyntax extends HoodieSparkSqlTestBase { var logicalPlan = sqlParser.parsePlan(s"show indexes from default.$tableName") var resolvedLogicalPlan = analyzer.execute(logicalPlan) - assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[ShowIndexesCommand].tableId.toString()) + assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[ShowIndexesCommand].table.identifier.quotedString) logicalPlan = sqlParser.parsePlan(s"create index idx_name on $tableName using lucene (name) options(block_size=1024)") resolvedLogicalPlan = analyzer.execute(logicalPlan) - assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].tableId.toString()) + assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].table.identifier.quotedString) assertResult("idx_name")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].indexName) assertResult("lucene")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].indexType) assertResult(false)(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].ignoreIfExists) @@ -68,7 +68,7 @@ class TestIndexSyntax extends HoodieSparkSqlTestBase { logicalPlan = sqlParser.parsePlan(s"create index if not exists idx_price on $tableName using lucene (price options(order='desc')) options(block_size=512)") resolvedLogicalPlan = analyzer.execute(logicalPlan) - assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].tableId.toString()) + assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].table.identifier.quotedString) assertResult("idx_price")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].indexName) assertResult("lucene")(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].indexType) assertResult(Map("order" -> "desc"))(resolvedLogicalPlan.asInstanceOf[CreateIndexCommand].columns.head._2) @@ -76,7 +76,7 @@ class TestIndexSyntax extends HoodieSparkSqlTestBase { logicalPlan = sqlParser.parsePlan(s"drop index if exists idx_name on $tableName") resolvedLogicalPlan = analyzer.execute(logicalPlan) - assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[DropIndexCommand].tableId.toString()) + 
assertResult(s"`default`.`$tableName`")(resolvedLogicalPlan.asInstanceOf[DropIndexCommand].table.identifier.quotedString) assertResult("idx_name")(resolvedLogicalPlan.asInstanceOf[DropIndexCommand].indexName) assertResult(true)(resolvedLogicalPlan.asInstanceOf[DropIndexCommand].ignoreIfNotExists) } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala index 92ccf02cab60b..ea5841ecdf43a 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystExpressionUtils.scala @@ -75,6 +75,12 @@ object HoodieSpark2CatalystExpressionUtils extends HoodieCatalystExpressionUtils } } + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = + expr match { + case Cast(child, dataType, timeZoneId) => Some((child, dataType, timeZoneId)) + case _ => None + } + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { expr match { case OrderPreservingTransformation(attrRef) => Some(attrRef) diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystPlanUtils.scala index 9be7198e6d432..fd84cfc201f96 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystPlanUtils.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/HoodieSpark2CatalystPlanUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedRelation} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Like} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.JoinType @@ -44,43 +44,20 @@ object HoodieSpark2CatalystPlanUtils extends HoodieCatalystPlansUtils { def createExplainCommand(plan: LogicalPlan, extended: Boolean): LogicalPlan = ExplainCommand(plan, extended = extended) - override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = { - TableIdentifier(aliasId.identifier, aliasId.database) - } - - override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = { - relation.tableIdentifier - } - override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = { Join(left, right, joinType, None) } - override def isInsertInto(plan: LogicalPlan): Boolean = { - plan.isInstanceOf[InsertIntoTable] - } - - override def getInsertIntoChildren(plan: LogicalPlan): - Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { + override def unapplyInsertIntoStatement(plan: LogicalPlan): Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { plan match { case InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists) => Some((table, partition, query, overwrite, ifPartitionNotExists)) - case _=> None + case _ => None } } - override def 
createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]], - query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = { - InsertIntoTable(table, partition, query, overwrite, ifPartitionNotExists) - } - - override def isRelationTimeTravel(plan: LogicalPlan): Boolean = { - false - } - - override def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] = { - throw new IllegalStateException(s"Should not call getRelationTimeTravel for spark2") - } + def rebaseInsertIntoStatement(iis: LogicalPlan, targetTable: LogicalPlan, query: LogicalPlan): LogicalPlan = + iis.asInstanceOf[InsertIntoTable].copy(table = targetTable, query = query) override def isRepairTable(plan: LogicalPlan): Boolean = { plan.isInstanceOf[AlterTableRecoverPartitionsCommand] diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala index c4d7c1ff5b05e..dd72282fec61e 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/adapter/Spark2Adapter.scala @@ -26,6 +26,8 @@ import org.apache.hudi.{AvroConversionUtils, DefaultSource, HoodieBaseRelation, import org.apache.spark.sql._ import org.apache.spark.sql.avro._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedPredicate} @@ -36,8 +38,10 @@ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Sp import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.hudi.SparkAdapter import org.apache.spark.sql.hudi.parser.HoodieSpark2ExtendedSqlParser +import org.apache.spark.sql.parser.HoodieExtendedParserInterface import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel._ @@ -55,14 +59,19 @@ class Spark2Adapter extends SparkAdapter { r.isInstanceOf[MutableColumnarRow] } + def createCatalystMetadataForMetaField: Metadata = + // NOTE: Since [[METADATA_COL_ATTR_KEY]] flag is not available in Spark 2.x, + // we simply produce an empty [[Metadata]] instance + new MetadataBuilder().build() + override def getCatalogUtils: HoodieCatalogUtils = { throw new UnsupportedOperationException("Catalog utilities are not supported in Spark 2.x"); } - override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark2CatalystExpressionUtils - override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark2CatalystPlanUtils + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark2CatalystExpressionUtils + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = new HoodieSpark2_4AvroSerializer(rootCatalystType, rootAvroType, nullable) @@ -76,18 +85,11 @@ class Spark2Adapter extends SparkAdapter 
{ new Spark2RowSerDe(encoder) } - override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { - Some( - (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark2ExtendedSqlParser(spark, delegate) - ) - } + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark2ExtendedSqlParser(spark, delegate) override def getSparkParsePartitionUtil: SparkParsePartitionUtil = Spark2ParsePartitionUtil - override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = { - throw new IllegalStateException(s"Should not call ParserInterface#parseMultipartIdentifier for spark2") - } - /** * Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`. * @@ -128,23 +130,6 @@ class Spark2Adapter extends SparkAdapter { partitions.toSeq } - override def isHoodieTable(table: LogicalPlan, spark: SparkSession): Boolean = { - super.isHoodieTable(table, spark) || - // NOTE: Following checks extending the logic of the base class specifically for Spark 2.x - (unfoldSubqueryAliases(table) match { - // This is to handle the cases when table is loaded by providing - // the path to the Spark DS and not from the catalog - // - // NOTE: This logic can't be relocated to the hudi-spark-client - case LogicalRelation(_: HoodieBaseRelation, _, _, _) => true - - case relation: UnresolvedRelation => - isHoodieTable(getCatalystPlanUtils.toTableIdentifier(relation), spark) - - case _ => false - }) - } - override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { Some(new Spark24HoodieParquetFileFormat(appendPartitionValues)) } @@ -170,22 +155,10 @@ class Spark2Adapter extends SparkAdapter { new Spark2HoodieFileScanRDD(sparkSession, readFunction, filePartitions) } - override def resolveDeleteFromTable(deleteFromTable: Command, - resolveExpression: Expression => Expression): DeleteFromTable = { - val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable] - val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression) - DeleteFromTable(deleteFromTableCommand.table, resolvedCondition) - } - override def extractDeleteCondition(deleteFromTable: Command): Expression = { deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null) } - override def getQueryParserFromExtendedSqlParser(session: SparkSession, delegate: ParserInterface, - sqlText: String): LogicalPlan = { - throw new UnsupportedOperationException(s"Unsupported parseQuery method in Spark earlier than Spark 3.3.0") - } - override def convertStorageLevelToString(level: StorageLevel): String = level match { case NONE => "NONE" case DISK_ONLY => "DISK_ONLY" diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/analysis/HoodieSpark2Analysis.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/analysis/HoodieSpark2Analysis.scala new file mode 100644 index 0000000000000..7691b90bfd74a --- /dev/null +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/analysis/HoodieSpark2Analysis.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Expression, ExtractValue, GetStructField, LambdaFunction} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Assignment, DeleteAction, InsertAction, LogicalPlan, MergeIntoTable, Project, UpdateAction, Window} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.toPrettySQL + +/** + * NOTE: This code is borrowed from Spark 3.1.3 + * This code is borrowed, so that we can have some advanced Spark SQL functionality (like Merge Into, for ex) + * in Spark 2.x + * + * PLEASE REFRAIN MAKING ANY CHANGES TO THIS CODE UNLESS ABSOLUTELY NECESSARY + */ +object HoodieSpark2Analysis { + + case class ResolveReferences(spark: SparkSession) extends Rule[LogicalPlan] { + + private val resolver = spark.sessionState.conf.resolver + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case m @ MergeIntoTable(targetTable, sourceTable, _, _, _) + if (!m.resolved || containsUnresolvedStarAssignments(m)) && targetTable.resolved && sourceTable.resolved => + + EliminateSubqueryAliases(targetTable) match { + case _ => + val newMatchedActions = m.matchedActions.map { + case DeleteAction(deleteCondition) => + val resolvedDeleteCondition = deleteCondition.map(resolveExpressionTopDown(_, m)) + DeleteAction(resolvedDeleteCondition) + case UpdateAction(updateCondition, assignments) => + val resolvedUpdateCondition = updateCondition.map(resolveExpressionTopDown(_, m)) + // The update value can access columns from both target and source tables. + UpdateAction( + resolvedUpdateCondition, + resolveAssignments(assignments, m, resolveValuesWithSourceOnly = false)) + case o => o + } + val newNotMatchedActions = m.notMatchedActions.map { + case InsertAction(insertCondition, assignments) => + // The insert action is used when not matched, so its condition and value can only + // access columns from the source table. 
+ val resolvedInsertCondition = + insertCondition.map(resolveExpressionTopDown(_, Project(Nil, m.sourceTable))) + InsertAction( + resolvedInsertCondition, + resolveAssignments(assignments, m, resolveValuesWithSourceOnly = true)) + case o => o + } + val resolvedMergeCondition = resolveExpressionTopDown(m.mergeCondition, m) + m.copy(mergeCondition = resolvedMergeCondition, + matchedActions = newMatchedActions, + notMatchedActions = newNotMatchedActions) + } + } + + private def resolveAssignments(assignments: Seq[Assignment], + mergeInto: MergeIntoTable, + resolveValuesWithSourceOnly: Boolean): Seq[Assignment] = { + if (assignments.isEmpty) { + val expandedColumns = mergeInto.targetTable.output + val expandedValues = mergeInto.sourceTable.output + expandedColumns.zip(expandedValues).map(kv => Assignment(kv._1, kv._2)) + } else { + assignments.map { assign => + val resolvedKey = assign.key match { + case c if !c.resolved => + resolveExpressionTopDown(c, Project(Nil, mergeInto.targetTable)) + case o => o + } + val resolvedValue = assign.value match { + // The update values may contain target and/or source references. + case c if !c.resolved => + if (resolveValuesWithSourceOnly) { + resolveExpressionTopDown(c, Project(Nil, mergeInto.sourceTable)) + } else { + resolveExpressionTopDown(c, mergeInto) + } + case o => o + } + Assignment(resolvedKey, resolvedValue) + } + } + } + + /** + * Resolves the attribute and extract value expressions(s) by traversing the + * input expression in top down manner. The traversal is done in top-down manner as + * we need to skip over unbound lambda function expression. The lambda expressions are + * resolved in a different rule [[ResolveLambdaVariables]] + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]] + * + * Note : In this routine, the unresolved attributes are resolved from the input plan's + * children attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @param trimAlias When true, trim unnecessary alias of `GetStructField`. Note that, + * we cannot trim the alias of top-level `GetStructField`, as we should + * resolve `UnresolvedAttribute` to a named expression. The caller side + * can trim the alias of top-level `GetStructField` if it's safe to do so. + * @return resolved Expression. + */ + private def resolveExpressionTopDown(e: Expression, + q: LogicalPlan, + trimAlias: Boolean = false): Expression = { + + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = { + if (e.resolved) return e + e match { + case f: LambdaFunction if !f.bound => f + case u@UnresolvedAttribute(nameParts) => + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val resolved = + withPosition(u) { + q.resolveChildren(nameParts, resolver) + .orElse(resolveLiteralFunction(nameParts, u, q)) + .getOrElse(u) + } + val result = resolved match { + // As the comment of method `resolveExpressionTopDown`'s param `trimAlias` said, + // when trimAlias = true, we will trim unnecessary alias of `GetStructField` and + // we won't trim the alias of top-level `GetStructField`. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias of + // `GetStructField` here is safe. 
+ case Alias(s: GetStructField, _) if trimAlias && !isTopLevel => s + case others => others + } + logDebug(s"Resolving $u to $result") + result + case UnresolvedExtractValue(child, fieldExpr) if child.resolved => + ExtractValue(child, fieldExpr, resolver) + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + } + + innerResolve(e, isTopLevel = true) + } + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String], + attribute: UnresolvedAttribute, + plan: LogicalPlan): Option[Expression] = { + if (nameParts.length != 1) return None + val isNamedExpression = plan match { + case Aggregate(_, aggregateExpressions, _) => aggregateExpressions.contains(attribute) + case Project(projectList, _) => projectList.contains(attribute) + case Window(windowExpressions, _, _, _) => windowExpressions.contains(attribute) + case _ => false + } + val wrapper: Expression => Expression = + if (isNamedExpression) f => Alias(f, toPrettySQL(f))() else identity + // support CURRENT_DATE and CURRENT_TIMESTAMP + val literalFunctions = Seq(CurrentDate(), CurrentTimestamp()) + val name = nameParts.head + val func = literalFunctions.find(e => caseInsensitiveResolution(e.prettyName, name)) + func.map(wrapper) + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // Following section is amended to the original (Spark's) implementation + // >>> BEGINS + //////////////////////////////////////////////////////////////////////////////////////////// + + private def containsUnresolvedStarAssignments(mit: MergeIntoTable): Boolean = { + val containsUnresolvedInsertStar = mit.notMatchedActions.exists { + case InsertAction(_, assignments) => assignments.isEmpty + case _ => false + } + val containsUnresolvedUpdateStar = mit.matchedActions.exists { + case UpdateAction(_, assignments) => assignments.isEmpty + case _ => false + } + + containsUnresolvedInsertStar || containsUnresolvedUpdateStar + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // <<< ENDS + //////////////////////////////////////////////////////////////////////////////////////////// + } + +} diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala index cc621a5b9f58d..faa11f02c3876 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeInto.scala @@ -33,6 +33,13 @@ case class MergeIntoTable( matchedActions: Seq[MergeAction], notMatchedActions: Seq[MergeAction]) extends Command { override def children: Seq[LogicalPlan] = Seq(targetTable, sourceTable) + + // NOTE: Overriding this field is necessary to disable application of the [[ResolveReferences]] + // of the standard resolution rule from Spark, such that [[MergeIntoTable]] is resolved + // by [[HoodieSpark2Analysis$ResolveReferences]] rule instead + override def childrenResolved: Boolean = false + + override lazy val resolved: Boolean = expressions.forall(_.resolved) && children.forall(_.resolved) } diff --git 
a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlParser.scala index ce32ae091f66b..f65f9188b4e5e 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlParser.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlParser.scala @@ -29,11 +29,12 @@ import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.parser.HoodieExtendedParserInterface import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SparkSession} class HoodieSpark2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) - extends ParserInterface with Logging { + extends HoodieExtendedParserInterface with Logging { private lazy val conf = session.sqlContext.conf private lazy val builder = new HoodieSpark2ExtendedSqlAstBuilder(conf, delegate) diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/HoodieSpark3CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/HoodieSpark3CatalystPlanUtils.scala index 552eb320161ff..9ceb6fa5ccd3a 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/HoodieSpark3CatalystPlanUtils.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/HoodieSpark3CatalystPlanUtils.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql -import org.apache.hudi.spark3.internal.ReflectUtil -import org.apache.spark.sql.catalyst.analysis.{TableOutputResolver, UnresolvedRelation} +import org.apache.hudi.SparkAdapterSupport +import org.apache.spark.sql.catalyst.analysis.TableOutputResolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, ProjectionOverSchema} import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan} -import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LeafNode, LogicalPlan} +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.{ExtendedMode, SimpleMode} import org.apache.spark.sql.internal.SQLConf @@ -36,7 +35,12 @@ trait HoodieSpark3CatalystPlanUtils extends HoodieCatalystPlansUtils { */ def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema - override def resolveOutputColumns(tableName: String, + /** + * Un-applies [[ResolvedTable]] that had its signature changed in Spark 3.2 + */ + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] + + def resolveOutputColumns(tableName: String, expected: Seq[Attribute], query: LogicalPlan, byName: Boolean, @@ -46,32 +50,11 @@ trait HoodieSpark3CatalystPlanUtils extends HoodieCatalystPlansUtils { override def createExplainCommand(plan: LogicalPlan, extended: Boolean): 
LogicalPlan = ExplainCommand(plan, mode = if (extended) ExtendedMode else SimpleMode) - override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = { - aliasId match { - case AliasIdentifier(name, Seq(database)) => - TableIdentifier(name, Some(database)) - case AliasIdentifier(name, Seq(_, database)) => - TableIdentifier(name, Some(database)) - case AliasIdentifier(name, Seq()) => - TableIdentifier(name, None) - case _ => throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier") - } - } - - override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = { - relation.multipartIdentifier.asTableIdentifier - } - override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = { Join(left, right, joinType, None, JoinHint.NONE) } - override def isInsertInto(plan: LogicalPlan): Boolean = { - plan.isInstanceOf[InsertIntoStatement] - } - - override def getInsertIntoChildren(plan: LogicalPlan): - Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { + override def unapplyInsertIntoStatement(plan: LogicalPlan): Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { plan match { case insert: InsertIntoStatement => Some((insert.table, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists)) @@ -80,8 +63,20 @@ trait HoodieSpark3CatalystPlanUtils extends HoodieCatalystPlansUtils { } } - override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]], - query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = { - ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists) + def rebaseInsertIntoStatement(iis: LogicalPlan, targetTable: LogicalPlan, query: LogicalPlan): LogicalPlan = + iis.asInstanceOf[InsertIntoStatement].copy(table = targetTable, query = query) +} + +object HoodieSpark3CatalystPlanUtils extends SparkAdapterSupport { + + /** + * This is an extractor to accommodate for [[ResolvedTable]] signature change in Spark 3.2 + */ + object MatchResolvedTable { + def unapply(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = + sparkAdapter.getCatalystPlanUtils match { + case spark3Utils: HoodieSpark3CatalystPlanUtils => spark3Utils.unapplyResolvedTable(plan) + case _ => None + } } } diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala index 1f82ce260edcf..022a1d247e18d 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/BaseSpark3Adapter.scala @@ -21,15 +21,18 @@ import org.apache.avro.Schema import org.apache.hadoop.fs.Path import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.hudi.common.table.HoodieTableMetaClient +import org.apache.hudi.{AvroConversionUtils, DefaultSource, Spark3RowSerDe} import org.apache.hudi.{AvroConversionUtils, DefaultSource, HoodieBaseRelation, Spark3RowSerDe} import org.apache.spark.internal.Logging +import org.apache.spark.sql.{HoodieSpark3CatalogUtils, SQLContext, SparkSession} import org.apache.spark.sql.avro.{HoodieAvroSchemaConverters, HoodieSparkAvroSchemaConverters} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import 
org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedPredicate, Predicate} -import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.V2TableWithV1Fallback import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.hudi.SparkAdapter @@ -40,7 +43,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel._ import scala.collection.JavaConverters.mapAsScalaMapConverter -import scala.util.control.NonFatal /** * Base implementation of [[SparkAdapter]] for Spark 3.x branch @@ -54,14 +56,24 @@ abstract class BaseSpark3Adapter extends SparkAdapter with Logging { new Spark3RowSerDe(encoder) } + override def resolveHoodieTable(plan: LogicalPlan): Option[CatalogTable] = { + super.resolveHoodieTable(plan).orElse { + EliminateSubqueryAliases(plan) match { + // First, we need to weed out unresolved plans + case plan if !plan.resolved => None + // NOTE: When resolving a Hudi table we allow [[Filter]]s and [[Project]]s to be applied + // on top of it + case PhysicalOperation(_, _, DataSourceV2Relation(v2: V2TableWithV1Fallback, _, _, _, _)) if isHoodieTable(v2.v1Table) => + Some(v2.v1Table) + case _ => None + } + } + } + override def getAvroSchemaConverters: HoodieAvroSchemaConverters = HoodieSparkAvroSchemaConverters override def getSparkParsePartitionUtil: SparkParsePartitionUtil = Spark3ParsePartitionUtil - override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = { - parser.parseMultipartIdentifier(sqlText) - } - /** * Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`.
*/ @@ -72,31 +84,6 @@ abstract class BaseSpark3Adapter extends SparkAdapter with Logging { FilePartition.getFilePartitions(sparkSession, partitionedFiles, maxSplitBytes) } - override def isHoodieTable(table: LogicalPlan, spark: SparkSession): Boolean = { - super.isHoodieTable(table, spark) || - // NOTE: Following checks extending the logic of the base class specifically for Spark 3.x - (unfoldSubqueryAliases(table) match { - case DataSourceV2Relation(table: Table, _, _, _, _) => isHoodieTable(table.properties()) - // This is to handle the cases when table is loaded by providing - // the path to the Spark DS and not from the catalog - // - // NOTE: This logic can't be relocated to the hudi-spark-client - case LogicalRelation(_: HoodieBaseRelation, _, _, _) => true - - case relation: UnresolvedRelation => - // TODO(HUDI-4503) clean-up try catch - try { - isHoodieTable(getCatalystPlanUtils.toTableIdentifier(relation), spark) - } catch { - case NonFatal(e) => - logWarning("Failed to determine whether the table is a hoodie table", e) - false - } - - case _ => false - }) - } - override def createInterpretedPredicate(e: Expression): InterpretedPredicate = { Predicate.createInterpreted(e) } diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala index d31c6a7b1a281..f565df80750cc 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystExpressionUtils.scala @@ -24,6 +24,12 @@ import org.apache.spark.sql.types.DataType object HoodieSpark31CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils { + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = + expr match { + case Cast(child, dataType, timeZoneId) => Some((child, dataType, timeZoneId)) + case _ => None + } + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { expr match { case OrderPreservingTransformation(attrRef) => Some(attrRef) diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystPlanUtils.scala index a4016f18cc614..6a53e23898683 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystPlanUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/HoodieSpark31CatalystPlanUtils.scala @@ -19,18 +19,20 @@ package org.apache.spark.sql import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ProjectionOverSchema} +import org.apache.spark.sql.catalyst.analysis.ResolvedTable +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, ProjectionOverSchema} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.execution.command.AlterTableRecoverPartitionsCommand import org.apache.spark.sql.types.StructType object HoodieSpark31CatalystPlanUtils extends HoodieSpark3CatalystPlanUtils { - override def isRelationTimeTravel(plan: LogicalPlan): 
Boolean = false - - override def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] = { - throw new IllegalStateException(s"Should not call getRelationTimeTravel for Spark <= 3.2.x") - } + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = + plan match { + case ResolvedTable(catalog, identifier, table) => Some((catalog, identifier, table)) + case _ => None + } override def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema = ProjectionOverSchema(schema) diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala index 0a4bc289b35e7..ae515a0ad320e 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala @@ -18,20 +18,24 @@ package org.apache.spark.sql.adapter -import org.apache.hudi.Spark31HoodieFileScanRDD import org.apache.avro.Schema -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} +import org.apache.hudi.Spark31HoodieFileScanRDD import org.apache.spark.sql.avro.{HoodieAvroDeserializer, HoodieAvroSerializer, HoodieSpark3_1AvroDeserializer, HoodieSpark3_1AvroSerializer} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.parser.HoodieSpark3_1ExtendedSqlParser -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark31HoodieParquetFileFormat} +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, LogicalRelation, PartitionedFile} import org.apache.spark.sql.hudi.SparkAdapter -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.parser.{HoodieExtendedParserInterface, HoodieSpark3_1ExtendedSqlParser} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructType} import org.apache.spark.sql.vectorized.ColumnarUtils -import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark31CatalogUtils, HoodieSpark31CatalystExpressionUtils, HoodieSpark31CatalystPlanUtils, HoodieSpark3CatalogUtils, SparkSession} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases +import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.connector.catalog.V2TableWithV1Fallback +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation /** * Implementation of [[SparkAdapter]] for Spark 3.1.x @@ -40,24 +44,25 @@ class Spark3_1Adapter extends BaseSpark3Adapter { override def isColumnarBatchRow(r: InternalRow): Boolean = ColumnarUtils.isColumnarBatchRow(r) - override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark31CatalogUtils + def createCatalystMetadataForMetaField: Metadata = + // NOTE: Since [[METADATA_COL_ATTR_KEY]] flag is not available in Spark 3.1.x, + // we simply produce an empty [[Metadata]] instance + new
MetadataBuilder().build() - override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark31CatalystExpressionUtils + override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark31CatalogUtils override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark31CatalystPlanUtils + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark31CatalystExpressionUtils + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = new HoodieSpark3_1AvroSerializer(rootCatalystType, rootAvroType, nullable) override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = new HoodieSpark3_1AvroDeserializer(rootAvroType, rootCatalystType) - override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { - // since spark3.2.1 support datasourceV2, so we need to a new SqlParser to deal DDL statment - Some( - (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_1ExtendedSqlParser(spark, delegate) - ) - } + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark3_1ExtendedSqlParser(spark, delegate) override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { Some(new Spark31HoodieParquetFileFormat(appendPartitionValues)) @@ -71,13 +76,6 @@ class Spark3_1Adapter extends BaseSpark3Adapter { new Spark31HoodieFileScanRDD(sparkSession, readFunction, filePartitions) } - override def resolveDeleteFromTable(deleteFromTable: Command, - resolveExpression: Expression => Expression): DeleteFromTable = { - val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable] - val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression) - DeleteFromTable(deleteFromTableCommand.table, resolvedCondition) - } - override def extractDeleteCondition(deleteFromTable: Command): Expression = { deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null) } diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark312SqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark31SqlAstBuilder.scala similarity index 98% rename from hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark312SqlAstBuilder.scala rename to hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark31SqlAstBuilder.scala index d92cceb9415ab..bee52ee33f4cf 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark312SqlAstBuilder.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark31SqlAstBuilder.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.SparkSqlAstBuilder // TODO: we should remove this file when we support datasourceV2 for hoodie on spark3.1x -class HoodieSpark312SqlAstBuilder(sparkSession: SparkSession) extends SparkSqlAstBuilder { +class HoodieSpark31SqlAstBuilder(sparkSession: SparkSession) extends SparkSqlAstBuilder { /** * Parse a [[AlterTableAlterColumnStatement]] command to alter a column's property. 
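For illustration, the per-version `unapplyResolvedTable` implementations introduced in this patch are meant to be consumed only through the shared `HoodieSpark3CatalystPlanUtils.MatchResolvedTable` extractor, so analysis rules can match a resolved catalog table without touching `ResolvedTable`'s constructor, whose arity differs between Spark 3.1 and Spark 3.2+. A minimal sketch of such a call site (the helper object below is hypothetical and not part of this change):

    import org.apache.spark.sql.HoodieSpark3CatalystPlanUtils.MatchResolvedTable
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    object ResolvedTableMatchExample {
      // Extracts the catalog, identifier and table backing a resolved DDL target,
      // staying source-compatible across Spark 3.x minor versions.
      def describeTarget(plan: LogicalPlan): Option[String] = plan match {
        case MatchResolvedTable(catalog, id, table) =>
          Some(s"${catalog.name()}.$id -> ${table.name()}")
        case _ => None
      }
    }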
diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala index 304e2984783e4..4b332dbc9e4df 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala @@ -20,10 +20,16 @@ package org.apache.spark.sql.parser import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkSqlAstBuilder, SparkSqlParser} // TODO: we should remove this file when we support datasourceV2 for hoodie on spark3.1x -class HoodieSpark3_1ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) extends SparkSqlParser with Logging { - override val astBuilder: SparkSqlAstBuilder = new HoodieSpark312SqlAstBuilder(session) +class HoodieSpark3_1ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) extends SparkSqlParser + with HoodieExtendedParserInterface + with Logging { + + override val astBuilder: SparkSqlAstBuilder = new HoodieSpark31SqlAstBuilder(session) + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = super[SparkSqlParser].parseMultipartIdentifier(sqlText) } diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala index 52c8de6bf7b74..c802d38deda6e 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystExpressionUtils.scala @@ -23,6 +23,13 @@ import org.apache.spark.sql.types.DataType object HoodieSpark32CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils { + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = + expr match { + case Cast(child, dataType, timeZoneId, _) => Some((child, dataType, timeZoneId)) + case AnsiCast(child, dataType, timeZoneId) => Some((child, dataType, timeZoneId)) + case _ => None + } + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { expr match { case OrderPreservingTransformation(attrRef) => Some(attrRef) diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystPlanUtils.scala index 0548fd47a4db8..604be90b3c80d 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystPlanUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/HoodieSpark32CatalystPlanUtils.scala @@ -21,6 +21,9 @@ package org.apache.spark.sql import org.apache.hudi.HoodieSparkUtils import org.apache.hudi.common.util.ValidationUtils.checkArgument +import org.apache.spark.sql.catalyst.analysis.ResolvedTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import 
org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ProjectionOverSchema} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TimeTravelRelation} @@ -29,18 +32,11 @@ import org.apache.spark.sql.types.StructType object HoodieSpark32CatalystPlanUtils extends HoodieSpark3CatalystPlanUtils { - override def isRelationTimeTravel(plan: LogicalPlan): Boolean = { - plan.isInstanceOf[TimeTravelRelation] - } - - override def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] = { + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = plan match { - case timeTravel: TimeTravelRelation => - Some((timeTravel.table, timeTravel.timestamp, timeTravel.version)) - case _ => - None + case ResolvedTable(catalog, identifier, table, _) => Some((catalog, identifier, table)) + case _ => None } - } override def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema = { val klass = classOf[ProjectionOverSchema] diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala index 4a37890557e22..7a4c5c9d172db 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala @@ -25,11 +25,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.{Command, DeleteFromTable} +import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark32PlusHoodieParquetFileFormat} import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} import org.apache.spark.sql.hudi.analysis.TableValuedFunctions -import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.parser.{HoodieExtendedParserInterface, HoodieSpark3_2ExtendedSqlParser} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructType} import org.apache.spark.sql.vectorized.ColumnarUtils /** @@ -39,23 +40,25 @@ class Spark3_2Adapter extends BaseSpark3Adapter { override def isColumnarBatchRow(r: InternalRow): Boolean = ColumnarUtils.isColumnarBatchRow(r) - override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark32CatalogUtils + def createCatalystMetadataForMetaField: Metadata = + new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, value = true) + .build() - override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark32CatalystExpressionUtils + override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark32CatalogUtils override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark32CatalystPlanUtils + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark32CatalystExpressionUtils + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): 
HoodieAvroSerializer = new HoodieSpark3_2AvroSerializer(rootCatalystType, rootAvroType, nullable) override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = new HoodieSpark3_2AvroDeserializer(rootAvroType, rootCatalystType) - override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { - Some( - (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_2ExtendedSqlParser(spark, delegate) - ) - } + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark3_2ExtendedSqlParser(spark, delegate) override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { Some(new Spark32PlusHoodieParquetFileFormat(appendPartitionValues)) @@ -69,13 +72,6 @@ class Spark3_2Adapter extends BaseSpark3Adapter { new Spark32HoodieFileScanRDD(sparkSession, readFunction, filePartitions) } - override def resolveDeleteFromTable(deleteFromTable: Command, - resolveExpression: Expression => Expression): DeleteFromTable = { - val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable] - val resolvedCondition = deleteFromTableCommand.condition.map(resolveExpression) - DeleteFromTable(deleteFromTableCommand.table, resolvedCondition) - } - override def extractDeleteCondition(deleteFromTable: Command): Expression = { deleteFromTable.asInstanceOf[DeleteFromTable].condition.getOrElse(null) } diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala index 89ab9bcf1e70e..196a77cb13af5 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala @@ -84,8 +84,7 @@ class HoodieSpark3_2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterfa table.optionalMap(ctx.sample)(withSample) } - private def withTimeTravel( - ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + private def withTimeTravel(ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { val v = ctx.version val version = if (ctx.INTEGER_VALUE != null) { Some(v.getText) diff --git a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala index 9c5d4d8488d46..1f8d02340d909 100644 --- a/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala +++ b/hudi-spark-datasource/hudi-spark3.2.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import java.util.Locale class HoodieSpark3_2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) - extends ParserInterface with Logging { + extends HoodieExtendedParserInterface with Logging { private lazy val conf = session.sqlContext.conf private lazy val builder = new HoodieSpark3_2ExtendedSqlAstBuilder(conf, delegate) diff --git 
a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logcal/HoodieQuery.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logcal/HoodieQuery.scala index d7e546875bc04..78860a98d8ab9 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logcal/HoodieQuery.scala +++ b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logcal/HoodieQuery.scala @@ -17,14 +17,10 @@ package org.apache.spark.sql.catalyst.plans.logcal -import org.apache.hudi.DefaultSource -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.hudi.common.util.ValidationUtils.checkState +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} -import org.apache.spark.sql.execution.datasources.LogicalRelation - -import scala.collection.mutable - +import org.apache.spark.sql.catalyst.plans.logical.LeafNode case class HoodieQuery(args: Seq[Expression]) extends LeafNode { @@ -38,36 +34,20 @@ object HoodieQuery { val FUNC_NAME = "hudi_query"; - def resolve(spark: SparkSession, func: HoodieQuery): LogicalPlan = { - - val args = func.args + def parseOptions(exprs: Seq[Expression]): (String, Map[String, String]) = { + val args = exprs.map(_.eval().toString) - val identifier = spark.sessionState.sqlParser.parseTableIdentifier(args.head.eval().toString) - val catalogTable = spark.sessionState.catalog.getTableMetadata(identifier) + val (Seq(identifier, queryMode), remaining) = args.splitAt(2) - val options = mutable.Map("path" -> catalogTable.location.toString) ++ parseOptions(args.tail) + val opts = queryMode match { + case "read_optimized" | "snapshot" => + checkState(remaining.isEmpty, s"No additional args are expected in `$queryMode` mode") + Map("hoodie.datasource.query.type" -> queryMode) - val hoodieDataSource = new DefaultSource - val relation = hoodieDataSource.createRelation(spark.sqlContext, options.toMap) - new LogicalRelation( - relation, - relation.schema.toAttributes, - Some(catalogTable), - false - ) - } - - private def parseOptions(args: Seq[Expression]): Map[String, String] = { - val options = mutable.Map.empty[String, String] - val queryMode = args.head.eval().toString - val instants = args.tail.map(_.eval().toString) - queryMode match { - case "read_optimized" => - assert(instants.isEmpty, "No expressions have to be provided in read_optimized mode.") - options += ("hoodie.datasource.query.type" -> "read_optimized") case _ => - throw new AnalysisException("hudi_query doesn't support other query modes for now.") + throw new AnalysisException(s"'hudi_query' doesn't currently support `$queryMode`") } - options.toMap + + (identifier, opts) } } diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala new file mode 100644 index 0000000000000..58841f19df27e --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark32PlusAnalysis.scala @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.analysis + +import org.apache.hudi.{DataSourceReadOptions, DefaultSource, SparkAdapterSupport} +import org.apache.spark.sql.HoodieSpark3CatalystPlanUtils.MatchResolvedTable +import org.apache.spark.sql.catalyst.analysis.UnresolvedPartitionSpec +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.plans.logcal.HoodieQuery +import org.apache.spark.sql.catalyst.plans.logcal.HoodieQuery.parseOptions +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper +import org.apache.spark.sql.connector.catalog.{Table, V1Table} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.datasources.{DataSource, LogicalRelation} +import org.apache.spark.sql.hudi.ProvidesHoodieConfig +import org.apache.spark.sql.hudi.analysis.HoodieSpark32PlusAnalysis.{HoodieV1OrV2Table, ResolvesToHudiTable} +import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table +import org.apache.spark.sql.hudi.command.{AlterHoodieTableDropPartitionCommand, ShowHoodieTablePartitionsCommand, TruncateHoodieTableCommand} +import org.apache.spark.sql.{AnalysisException, SQLContext, SparkSession} + +/** + * NOTE: PLEASE READ CAREFULLY + * + * Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here. + * Such fallback will have considerable performance impact, therefore it's only performed in cases + * where V2 API have to be used. 
Currently, the only such use-case is the use of the Schema Evolution feature + * + * Check out HUDI-4178 for more details + */ +case class HoodieDataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] + with ProvidesHoodieConfig { + + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + // The only place we're avoiding fallback is in [[AlterTableCommand]]s since + // the current implementation relies on DSv2 features + case _: AlterTableCommand => plan + + // NOTE: Unfortunately, [[InsertIntoStatement]] is implemented in a way that doesn't expose + // the target relation as a child (even though there's no good reason for that) + case iis @ InsertIntoStatement(rv2 @ DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), _, _, _, _, _) => + iis.copy(table = convertToV1(rv2, v2Table)) + + case _ => + plan.resolveOperatorsDown { + case rv2 @ DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => convertToV1(rv2, v2Table) + } + } + + private def convertToV1(rv2: DataSourceV2Relation, v2Table: HoodieInternalV2Table) = { + val output = rv2.output + val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) + val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), + buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) + + LogicalRelation(relation, output, catalogTable, isStreaming = false) + } +} + +/** + * Rule resolving Hudi's extended syntax and rewriting certain logical plans. + */ +case class HoodieSpark32PlusResolveReferences(spark: SparkSession) extends Rule[LogicalPlan] + with SparkAdapterSupport with ProvidesHoodieConfig { + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { + case TimeTravelRelation(ResolvesToHudiTable(table), timestamp, version) => + if (timestamp.isEmpty && version.nonEmpty) { + throw new AnalysisException("Version expression is not supported for time travel") + } + + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) + val dataSource = + DataSource( + spark, + userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), + partitionColumns = table.partitionColumnNames, + bucketSpec = table.bucketSpec, + className = table.provider.get, + options = table.storage.properties ++ pathOption ++ Map( + DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> timestamp.get.toString()), + catalogTable = Some(table)) + + val relation = dataSource.resolveRelation(checkFilesExist = false) + + LogicalRelation(relation, table) + + case q: HoodieQuery => + val (tableName, opts) = parseOptions(q.args) + + val tableId = spark.sessionState.sqlParser.parseTableIdentifier(tableName) + val catalogTable = spark.sessionState.catalog.getTableMetadata(tableId) + + val hoodieDataSource = new DefaultSource + val relation = hoodieDataSource.createRelation(spark.sqlContext, opts ++ Map("path" -> + catalogTable.location.toString)) + + LogicalRelation(relation, catalogTable) + } +} + +/** + * Rule replacing resolved Spark commands (which don't work for Hudi tables out-of-the-box) with + * the corresponding Hudi implementations + */ +case class HoodieSpark32PlusPostAnalysisRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + plan match { + case ShowPartitions(MatchResolvedTable(_, id, HoodieV1OrV2Table(_)), specOpt, _) => + ShowHoodieTablePartitionsCommand( + id.asTableIdentifier, specOpt.map(s => s.asInstanceOf[UnresolvedPartitionSpec].spec)) + + // Rewrite
TruncateTableCommand to TruncateHoodieTableCommand + case TruncateTable(MatchResolvedTable(_, id, HoodieV1OrV2Table(_))) => + TruncateHoodieTableCommand(id.asTableIdentifier, None) + + case TruncatePartition(MatchResolvedTable(_, id, HoodieV1OrV2Table(_)), partitionSpec: UnresolvedPartitionSpec) => + TruncateHoodieTableCommand(id.asTableIdentifier, Some(partitionSpec.spec)) + + case DropPartitions(MatchResolvedTable(_, id, HoodieV1OrV2Table(_)), specs, ifExists, purge) => + AlterHoodieTableDropPartitionCommand( + id.asTableIdentifier, + specs.seq.map(f => f.asInstanceOf[UnresolvedPartitionSpec]).map(s => s.spec), + ifExists, + purge, + retainData = true + ) + + case _ => plan + } + } +} + +object HoodieSpark32PlusAnalysis extends SparkAdapterSupport { + + private[sql] object HoodieV1OrV2Table { + def unapply(table: Table): Option[CatalogTable] = table match { + case V1Table(catalogTable) if sparkAdapter.isHoodieTable(catalogTable) => Some(catalogTable) + case v2: HoodieInternalV2Table => v2.catalogTable + case _ => None + } + } + + // TODO dedup w/ HoodieAnalysis + private[sql] object ResolvesToHudiTable { + def unapply(plan: LogicalPlan): Option[CatalogTable] = + sparkAdapter.resolveHoodieTable(plan) + } +} + + diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark3Analysis.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark3Analysis.scala deleted file mode 100644 index abae14c70e809..0000000000000 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieSpark3Analysis.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hudi.analysis - -import org.apache.hudi.common.table.HoodieTableMetaClient -import org.apache.hudi.{DefaultSource, SparkAdapterSupport} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{ResolvedTable, UnresolvedPartitionSpec} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HoodieCatalogTable} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} -import org.apache.spark.sql.catalyst.plans.logcal.HoodieQuery -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper -import org.apache.spark.sql.connector.catalog.{Table, V1Table} -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.PreWriteCheck.failAnalysis -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{castIfNeeded, getTableLocation, removeMetaFields, tableExistsInPath} -import org.apache.spark.sql.hudi.catalog.HoodieInternalV2Table -import org.apache.spark.sql.hudi.command.{AlterHoodieTableDropPartitionCommand, ShowHoodieTablePartitionsCommand, TruncateHoodieTableCommand} -import org.apache.spark.sql.hudi.{HoodieSqlCommonUtils, ProvidesHoodieConfig} -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{AnalysisException, SQLContext, SparkSession} - -import scala.collection.JavaConverters.mapAsJavaMapConverter - -/** - * NOTE: PLEASE READ CAREFULLY - * - * Since Hudi relations don't currently implement DS V2 Read API, we have to fallback to V1 here. - * Such fallback will have considerable performance impact, therefore it's only performed in cases - * where V2 API have to be used. 
Currently only such use-case is using of Schema Evolution feature - * - * Check out HUDI-4178 for more details - */ -class HoodieDataSourceV2ToV1Fallback(sparkSession: SparkSession) extends Rule[LogicalPlan] - with ProvidesHoodieConfig { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown { - case v2r @ DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _) => - val output = v2r.output - val catalogTable = v2Table.catalogTable.map(_ => v2Table.v1Table) - val relation = new DefaultSource().createRelation(new SQLContext(sparkSession), - buildHoodieConfig(v2Table.hoodieCatalogTable), v2Table.hoodieCatalogTable.tableSchema) - - LogicalRelation(relation, output, catalogTable, isStreaming = false) - } -} - -class HoodieSpark3Analysis(sparkSession: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsDown { - case s @ InsertIntoStatement(r @ DataSourceV2Relation(v2Table: HoodieInternalV2Table, _, _, _, _), partitionSpec, _, _, _, _) - if s.query.resolved && needsSchemaAdjustment(s.query, v2Table.hoodieCatalogTable.table, partitionSpec, r.schema) => - val projection = resolveQueryColumnsByOrdinal(s.query, r.output) - if (projection != s.query) { - s.copy(query = projection) - } else { - s - } - - case query: HoodieQuery => - HoodieQuery.resolve(sparkSession, query) - } - - /** - * Need to adjust schema based on the query and relation schema, for example, - * if using insert into xx select 1, 2 here need to map to column names - */ - private def needsSchemaAdjustment(query: LogicalPlan, - table: CatalogTable, - partitionSpec: Map[String, Option[String]], - schema: StructType): Boolean = { - val output = query.output - val queryOutputWithoutMetaFields = removeMetaFields(output) - val hoodieCatalogTable = HoodieCatalogTable(sparkSession, table) - - val partitionFields = hoodieCatalogTable.partitionFields - val partitionSchema = hoodieCatalogTable.partitionSchema - val staticPartitionValues = partitionSpec.filter(p => p._2.isDefined).mapValues(_.get) - - assert(staticPartitionValues.isEmpty || - staticPartitionValues.size == partitionSchema.size, - s"Required partition columns is: ${partitionSchema.json}, Current static partitions " + - s"is: ${staticPartitionValues.mkString("," + "")}") - - assert(staticPartitionValues.size + queryOutputWithoutMetaFields.size - == hoodieCatalogTable.tableSchemaWithoutMetaFields.size, - s"Required select columns count: ${hoodieCatalogTable.tableSchemaWithoutMetaFields.size}, " + - s"Current select columns(including static partition column) count: " + - s"${staticPartitionValues.size + queryOutputWithoutMetaFields.size},columns: " + - s"(${(queryOutputWithoutMetaFields.map(_.name) ++ staticPartitionValues.keys).mkString(",")})") - - // static partition insert. - if (staticPartitionValues.nonEmpty) { - // drop partition fields in origin schema to align fields. - schema.dropWhile(p => partitionFields.contains(p.name)) - } - - val existingSchemaOutput = output.take(schema.length) - existingSchemaOutput.map(_.name) != schema.map(_.name) || - existingSchemaOutput.map(_.dataType) != schema.map(_.dataType) - } - - private def resolveQueryColumnsByOrdinal(query: LogicalPlan, - targetAttrs: Seq[Attribute]): LogicalPlan = { - // always add a Cast. it will be removed in the optimizer if it is unnecessary. 
- val project = query.output.zipWithIndex.map { case (attr, i) => - if (i < targetAttrs.length) { - val targetAttr = targetAttrs(i) - val castAttr = castIfNeeded(attr.withNullability(targetAttr.nullable), targetAttr.dataType, conf) - Alias(castAttr, targetAttr.name)() - } else { - attr - } - } - Project(project, query) - } -} - -/** - * Rule replacing resolved Spark's commands (not working for Hudi tables out-of-the-box) with - * corresponding Hudi implementations - */ -case class HoodieSpark3PostAnalysisRule(sparkSession: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan match { - case ShowPartitions(ResolvedTable(_, id, HoodieV1OrV2Table(_), _), specOpt, _) => - ShowHoodieTablePartitionsCommand( - id.asTableIdentifier, specOpt.map(s => s.asInstanceOf[UnresolvedPartitionSpec].spec)) - - // Rewrite TruncateTableCommand to TruncateHoodieTableCommand - case TruncateTable(ResolvedTable(_, id, HoodieV1OrV2Table(_), _)) => - TruncateHoodieTableCommand(id.asTableIdentifier, None) - - case TruncatePartition(ResolvedTable(_, id, HoodieV1OrV2Table(_), _), partitionSpec: UnresolvedPartitionSpec) => - TruncateHoodieTableCommand(id.asTableIdentifier, Some(partitionSpec.spec)) - - case DropPartitions(ResolvedTable(_, id, HoodieV1OrV2Table(_), _), specs, ifExists, purge) => - AlterHoodieTableDropPartitionCommand( - id.asTableIdentifier, - specs.seq.map(f => f.asInstanceOf[UnresolvedPartitionSpec]).map(s => s.spec), - ifExists, - purge, - retainData = true - ) - - case _ => plan - } - } -} - -private[sql] object HoodieV1OrV2Table extends SparkAdapterSupport { - def unapply(table: Table): Option[CatalogTable] = table match { - case V1Table(catalogTable) if sparkAdapter.isHoodieTable(catalogTable) => Some(catalogTable) - case v2: HoodieInternalV2Table => v2.catalogTable - case _ => None - } -} - diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieCatalog.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieCatalog.scala index 6d3610db21eda..bb2884382c23f 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieCatalog.scala +++ b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieCatalog.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChan import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.hudi.analysis.HoodieV1OrV2Table +import org.apache.spark.sql.hudi.analysis.HoodieSpark32PlusAnalysis.HoodieV1OrV2Table import org.apache.spark.sql.hudi.command._ import org.apache.spark.sql.hudi.{HoodieSqlCommonUtils, ProvidesHoodieConfig} import org.apache.spark.sql.types.{StructField, StructType} diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieInternalV2Table.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieInternalV2Table.scala index b41c7456b7125..2ec000ee9804a 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieInternalV2Table.scala +++ 
b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/catalog/HoodieInternalV2Table.scala @@ -21,7 +21,7 @@ import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HoodieCatalogTable} import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability, V1Table, V2TableWithV1Fallback} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.hudi.ProvidesHoodieConfig @@ -85,8 +85,8 @@ case class HoodieInternalV2Table(spark: SparkSession, } private class HoodieV1WriteBuilder(writeOptions: CaseInsensitiveStringMap, - hoodieCatalogTable: HoodieCatalogTable, - spark: SparkSession) + hoodieCatalogTable: HoodieCatalogTable, + spark: SparkSession) extends SupportsTruncate with SupportsOverwrite with ProvidesHoodieConfig { private var overwriteTable = false @@ -106,7 +106,7 @@ private class HoodieV1WriteBuilder(writeOptions: CaseInsensitiveStringMap, override def toInsertableRelation: InsertableRelation = { new InsertableRelation { override def insert(data: DataFrame, overwrite: Boolean): Unit = { - alignOutputColumns(data).write.format("org.apache.hudi") + data.write.format("org.apache.hudi") .mode(SaveMode.Append) .options(buildHoodieConfig(hoodieCatalogTable) ++ buildHoodieInsertConfig(hoodieCatalogTable, spark, overwritePartition, overwriteTable, Map.empty, Map.empty)) @@ -115,9 +115,4 @@ private class HoodieV1WriteBuilder(writeOptions: CaseInsensitiveStringMap, } } } - - private def alignOutputColumns(data: DataFrame): DataFrame = { - val schema = hoodieCatalogTable.tableSchema - spark.createDataFrame(data.toJavaRDD, schema) - } } diff --git a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/logical/TimeTravelRelation.scala b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/logical/TimeTravelRelation.scala index f243a7a86174f..829971db0e30f 100644 --- a/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/logical/TimeTravelRelation.scala +++ b/hudi-spark-datasource/hudi-spark3.2plus-common/src/main/scala/org/apache/spark/sql/hudi/logical/TimeTravelRelation.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -case class TimeTravelRelation( - table: LogicalPlan, - timestamp: Option[Expression], - version: Option[String]) extends Command { - override def children: Seq[LogicalPlan] = Seq.empty - - override def output: Seq[Attribute] = Nil +case class TimeTravelRelation(relation: LogicalPlan, + timestamp: Option[Expression], + version: Option[String]) extends UnaryNode with HoodieUnaryLikeSham[LogicalPlan] { override lazy val resolved: Boolean = false - def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this + override def output: Seq[Attribute] = Nil + override def child: LogicalPlan = relation + + override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = + copy(relation = newChild) } diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala 
b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala index d68c9f373fbc1..94bd089522e7e 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystExpressionUtils.scala @@ -23,6 +23,13 @@ import org.apache.spark.sql.types.DataType object HoodieSpark33CatalystExpressionUtils extends HoodieSpark3CatalystExpressionUtils { + override def matchCast(expr: Expression): Option[(Expression, DataType, Option[String])] = + expr match { + case Cast(child, dataType, timeZoneId, _) => Some((child, dataType, timeZoneId)) + case AnsiCast(child, dataType, timeZoneId) => Some((child, dataType, timeZoneId)) + case _ => None + } + override def tryMatchAttributeOrderingPreservingTransformation(expr: Expression): Option[AttributeReference] = { expr match { case OrderPreservingTransformation(attrRef) => Some(attrRef) diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystPlanUtils.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystPlanUtils.scala index c3642544889ce..13ad1ae5a8606 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystPlanUtils.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/HoodieSpark33CatalystPlanUtils.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.analysis.ResolvedTable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ProjectionOverSchema} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, TimeTravelRelation} @@ -26,18 +29,11 @@ import org.apache.spark.sql.types.StructType object HoodieSpark33CatalystPlanUtils extends HoodieSpark3CatalystPlanUtils { - override def isRelationTimeTravel(plan: LogicalPlan): Boolean = { - plan.isInstanceOf[TimeTravelRelation] - } - - override def getRelationTimeTravel(plan: LogicalPlan): Option[(LogicalPlan, Option[Expression], Option[String])] = { + def unapplyResolvedTable(plan: LogicalPlan): Option[(TableCatalog, Identifier, Table)] = plan match { - case timeTravel: TimeTravelRelation => - Some((timeTravel.table, timeTravel.timestamp, timeTravel.version)) - case _ => - None + case ResolvedTable(catalog, identifier, table, _) => Some((catalog, identifier, table)) + case _ => None } - } override def projectOverSchema(schema: StructType, output: AttributeSet): ProjectionOverSchema = ProjectionOverSchema(schema, output) diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_3Adapter.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_3Adapter.scala index 696c044223f2d..857147eb66c25 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_3Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_3Adapter.scala @@ -19,19 +19,19 @@ package org.apache.spark.sql.adapter import org.apache.avro.Schema import org.apache.hudi.Spark33HoodieFileScanRDD -import 
org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql._ import org.apache.spark.sql.avro._ +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} +import org.apache.spark.sql.catalyst.util.METADATA_COL_ATTR_KEY import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, Spark32PlusHoodieParquetFileFormat} +import org.apache.spark.sql.execution.datasources.{FilePartition, FileScanRDD, PartitionedFile} import org.apache.spark.sql.hudi.analysis.TableValuedFunctions -import org.apache.spark.sql.parser.HoodieSpark3_3ExtendedSqlParser -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.parser.{HoodieExtendedParserInterface, HoodieSpark3_3ExtendedSqlParser} +import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructType} import org.apache.spark.sql.vectorized.ColumnarBatchRow -import org.apache.spark.sql.{HoodieCatalystExpressionUtils, HoodieCatalystPlansUtils, HoodieSpark33CatalogUtils, HoodieSpark33CatalystExpressionUtils, HoodieSpark33CatalystPlanUtils, HoodieSpark3CatalogUtils, SparkSession} /** * Implementation of [[SparkAdapter]] for Spark 3.3.x branch @@ -40,23 +40,25 @@ class Spark3_3Adapter extends BaseSpark3Adapter { override def isColumnarBatchRow(r: InternalRow): Boolean = r.isInstanceOf[ColumnarBatchRow] - override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark33CatalogUtils + def createCatalystMetadataForMetaField: Metadata = + new MetadataBuilder() + .putBoolean(METADATA_COL_ATTR_KEY, value = true) + .build() - override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark33CatalystExpressionUtils + override def getCatalogUtils: HoodieSpark3CatalogUtils = HoodieSpark33CatalogUtils override def getCatalystPlanUtils: HoodieCatalystPlansUtils = HoodieSpark33CatalystPlanUtils + override def getCatalystExpressionUtils: HoodieCatalystExpressionUtils = HoodieSpark33CatalystExpressionUtils + override def createAvroSerializer(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean): HoodieAvroSerializer = new HoodieSpark3_3AvroSerializer(rootCatalystType, rootAvroType, nullable) override def createAvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType): HoodieAvroDeserializer = new HoodieSpark3_3AvroDeserializer(rootAvroType, rootCatalystType) - override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { - Some( - (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_3ExtendedSqlParser(spark, delegate) - ) - } + override def createExtendedSparkParser(spark: SparkSession, delegate: ParserInterface): HoodieExtendedParserInterface = + new HoodieSpark3_3ExtendedSqlParser(spark, delegate) override def createHoodieParquetFileFormat(appendPartitionValues: Boolean): Option[ParquetFileFormat] = { Some(new Spark32PlusHoodieParquetFileFormat(appendPartitionValues)) @@ -70,21 +72,10 @@ class Spark3_3Adapter extends BaseSpark3Adapter { new Spark33HoodieFileScanRDD(sparkSession, readFunction, filePartitions, readDataSchema, metadataColumns) } - override def resolveDeleteFromTable(deleteFromTable: Command, - resolveExpression: Expression => Expression): DeleteFromTable = 
{ - val deleteFromTableCommand = deleteFromTable.asInstanceOf[DeleteFromTable] - DeleteFromTable(deleteFromTableCommand.table, resolveExpression(deleteFromTableCommand.condition)) - } - override def extractDeleteCondition(deleteFromTable: Command): Expression = { deleteFromTable.asInstanceOf[DeleteFromTable].condition } - override def getQueryParserFromExtendedSqlParser(session: SparkSession, delegate: ParserInterface, - sqlText: String): LogicalPlan = { - new HoodieSpark3_3ExtendedSqlParser(session, delegate).parseQuery(sqlText) - } - override def injectTableFunctions(extensions: SparkSessionExtensions): Unit = { TableValuedFunctions.funcs.foreach(extensions.injectTableFunction) } diff --git a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_3ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_3ExtendedSqlParser.scala index 36b8bd3608eb2..4c59f56828f2d 100644 --- a/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_3ExtendedSqlParser.scala +++ b/hudi-spark-datasource/hudi-spark3.3.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_3ExtendedSqlParser.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import java.util.Locale class HoodieSpark3_3ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) - extends ParserInterface with Logging { + extends HoodieExtendedParserInterface with Logging { private lazy val conf = session.sqlContext.conf private lazy val builder = new HoodieSpark3_3ExtendedSqlAstBuilder(conf, delegate) @@ -56,9 +56,7 @@ class HoodieSpark3_3ExtendedSqlParser(session: SparkSession, delegate: ParserInt } } - // SPARK-37266 Added parseQuery to ParserInterface in Spark 3.3.0 - // Don't mark this as override for backward compatibility - def parseQuery(sqlText: String): LogicalPlan = delegate.parseQuery(sqlText) + override def parseQuery(sqlText: String): LogicalPlan = delegate.parseQuery(sqlText) override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)
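To illustrate the reworked parser hook above: each Spark 3.x adapter now builds a `HoodieExtendedParserInterface` directly from `createExtendedSparkParser(spark, delegate)` instead of returning an `Option` of a factory function. A hedged sketch of how such a factory can be wired through `SparkSessionExtensions` (assuming `HoodieExtendedParserInterface` extends Spark's `ParserInterface` and the base `SparkAdapter` declares this signature; the wrapper object is illustrative only):

    import org.apache.hudi.SparkAdapterSupport.sparkAdapter
    import org.apache.spark.sql.SparkSessionExtensions

    object ExtendedParserInjectionExample {
      def injectHoodieParser(extensions: SparkSessionExtensions): Unit = {
        // The adapter-provided parser wraps the session's delegate parser and adds
        // Hudi's extended SQL syntax on top of it.
        extensions.injectParser { (session, delegate) =>
          sparkAdapter.createExtendedSparkParser(session, delegate)
        }
      }
    }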