diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala
index f058b47d782d5..2fc7fb017ad5e 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/InsertIntoHoodieTableCommand.scala
@@ -20,13 +20,14 @@ package org.apache.spark.sql.hudi.command
 import org.apache.hudi.HoodieSparkSqlWriter
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HoodieCatalogTable}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Literal}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
 import org.apache.spark.sql.hudi.ProvidesHoodieConfig
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 
 /**
@@ -120,50 +121,45 @@ object InsertIntoHoodieTableCommand extends Logging with ProvidesHoodieConfig {
     val staticPartitionValues = insertPartitions.filter(p => p._2.isDefined).mapValues(_.get)
     assert(staticPartitionValues.isEmpty ||
-      staticPartitionValues.size == targetPartitionSchema.size,
-      s"Required partition columns is: ${targetPartitionSchema.json}, Current static partitions " +
+      insertPartitions.size == targetPartitionSchema.size,
+      s"Required partition columns is: ${targetPartitionSchema.json}, Current input partitions " +
       s"is: ${staticPartitionValues.mkString("," + "")}")
 
     val queryOutputWithoutMetaFields = removeMetaFields(query.output)
     assert(staticPartitionValues.size + queryOutputWithoutMetaFields.size
-        == hoodieCatalogTable.tableSchemaWithoutMetaFields.size,
+      == hoodieCatalogTable.tableSchemaWithoutMetaFields.size,
       s"Required select columns count: ${hoodieCatalogTable.tableSchemaWithoutMetaFields.size}, " +
       s"Current select columns(including static partition column) count: " +
       s"${staticPartitionValues.size + queryOutputWithoutMetaFields.size},columns: " +
       s"(${(queryOutputWithoutMetaFields.map(_.name) ++ staticPartitionValues.keys).mkString(",")})")
 
-    val queryDataFieldsWithoutMetaFields = if (staticPartitionValues.isEmpty) { // insert dynamic partition
-      queryOutputWithoutMetaFields.dropRight(targetPartitionSchema.fields.length)
-    } else { // insert static partition
-      queryOutputWithoutMetaFields
-    }
-    // Align for the data fields of the query
-    val dataProjectsWithoutMetaFields = queryDataFieldsWithoutMetaFields.zip(
-      hoodieCatalogTable.dataSchemaWithoutMetaFields.fields).map { case (dataAttr, targetField) =>
-        val castAttr = castIfNeeded(dataAttr.withNullability(targetField.nullable),
-          targetField.dataType, conf)
-        Alias(castAttr, targetField.name)()
-    }
+    val dataAndDynamicPartitionSchemaWithoutMetaFields = StructType(
+      hoodieCatalogTable.tableSchemaWithoutMetaFields.filterNot(f => staticPartitionValues.contains(f.name)))
+    val dataProjectsWithoutMetaFields = getTableFieldsAlias(queryOutputWithoutMetaFields,
+      dataAndDynamicPartitionSchemaWithoutMetaFields.fields, conf)
 
-    val partitionProjects = if (staticPartitionValues.isEmpty) { // insert dynamic partitions
-      // The partition attributes is followed the data attributes in the query
-      // So we init the partitionAttrPosition with the data schema size.
-      var partitionAttrPosition = hoodieCatalogTable.dataSchemaWithoutMetaFields.size
-      targetPartitionSchema.fields.map(f => {
-        val partitionAttr = queryOutputWithoutMetaFields(partitionAttrPosition)
-        partitionAttrPosition = partitionAttrPosition + 1
-        val castAttr = castIfNeeded(partitionAttr.withNullability(f.nullable), f.dataType, conf)
-        Alias(castAttr, f.name)()
-      })
-    } else { // insert static partitions
-      targetPartitionSchema.fields.map(f => {
+    val partitionProjects = targetPartitionSchema.fields.filter(f => staticPartitionValues.contains(f.name))
+      .map(f => {
         val staticPartitionValue = staticPartitionValues.getOrElse(f.name,
-          s"Missing static partition value for: ${f.name}")
+        s"Missing static partition value for: ${f.name}")
         val castAttr = castIfNeeded(Literal.create(staticPartitionValue), f.dataType, conf)
         Alias(castAttr, f.name)()
       })
+
+    Project(dataProjectsWithoutMetaFields ++ partitionProjects, query)
+  }
+
+  private def getTableFieldsAlias(
+      queryOutputWithoutMetaFields: Seq[Attribute],
+      schemaWithoutMetaFields: Seq[StructField],
+      conf: SQLConf): Seq[Alias] = {
+    queryOutputWithoutMetaFields.zip(schemaWithoutMetaFields).map { case (dataAttr, dataField) =>
+      // Attribute names like "col1" are auto-generated (e.g. by a VALUES clause),
+      // so match those by position; otherwise try to match the table field by name.
+      val targetFieldOption = if (dataAttr.name.startsWith("col")) None else
+        schemaWithoutMetaFields.find(_.name.equals(dataAttr.name))
+      val targetField = if (targetFieldOption.isDefined) targetFieldOption.get else dataField
+      val castAttr = castIfNeeded(dataAttr.withNullability(targetField.nullable),
+        targetField.dataType, conf)
+      Alias(castAttr, targetField.name)()
     }
-    val alignedProjects = dataProjectsWithoutMetaFields ++ partitionProjects
-    Project(alignedProjects, query)
   }
 }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
index 3f9066d0847f5..54dd45f3f5aaa 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestInsertTable.scala
@@ -29,7 +29,7 @@ import java.io.File
 
 class TestInsertTable extends HoodieSparkSqlTestBase {
 
-  test("Test Insert Into") {
+  test("Test Insert Into with values") {
     withTempDir { tmp =>
       val tableName = generateTableName
       // Create a partitioned table
@@ -37,33 +37,173 @@ class TestInsertTable extends HoodieSparkSqlTestBase {
         s"""
            |create table $tableName (
            | id int,
+           | dt string,
            | name string,
            | price double,
-           | ts long,
-           | dt string
+           | ts long
            |) using hudi
            | tblproperties (primaryKey = 'id')
            | partitioned by (dt)
            | location '${tmp.getCanonicalPath}'
        """.stripMargin)
-      // Insert into dynamic partition
+
+      // Note: no field aliases are written here, so columns are matched by position and the partition field must be placed last.
       spark.sql(
         s"""
-           | insert into $tableName
-           | select 1 as id, 'a1' as name, 10 as price, 1000 as ts, '2021-01-05' as dt
-        """.stripMargin)
+           | insert into $tableName values
+           | (1, 'a1', 10, 1000, "2021-01-05"),
+           | (2, 'a2', 20, 2000, "2021-01-06"),
+           | (3, 'a3', 30, 3000, "2021-01-07")
+        """.stripMargin)
+
       checkAnswer(s"select id, name, price, ts, dt from $tableName")(
-        Seq(1, "a1", 10.0, 1000, "2021-01-05")
+        Seq(1, "a1", 10.0, 1000, "2021-01-05"),
+        Seq(2, "a2", 20.0, 2000, "2021-01-06"),
+        Seq(3, "a3", 30.0, 3000, "2021-01-07")
       )
+    }
+  }
+
+  test("Test Insert Into with static partition") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a partitioned table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           | id int,
+           | dt string,
+           | name string,
+           | price double,
+           | ts long
+           |) using hudi
+           | tblproperties (primaryKey = 'id')
+           | partitioned by (dt)
+           | location '${tmp.getCanonicalPath}'
+       """.stripMargin)
 
       // Insert into static partition
       spark.sql(
         s"""
            | insert into $tableName partition(dt = '2021-01-05')
-           | select 2 as id, 'a2' as name, 10 as price, 1000 as ts
+           | select 1 as id, 'a1' as name, 10 as price, 1000 as ts
+        """.stripMargin)
+
+      // With field aliases the select columns can be written in any order.
+      spark.sql(
+        s"""
+           | insert into $tableName partition(dt = '2021-01-06')
+           | select 20 as price, 2000 as ts, 2 as id, 'a2' as name
+        """.stripMargin)
+
+      // Note: no field aliases are written here, so columns are matched by position and the partition field must be placed last.
+      spark.sql(
+        s"""
+           | insert into $tableName
+           | select 3, 'a3', 30, 3000, '2021-01-07'
         """.stripMargin)
+
       checkAnswer(s"select id, name, price, ts, dt from $tableName")(
         Seq(1, "a1", 10.0, 1000, "2021-01-05"),
-        Seq(2, "a2", 10.0, 1000, "2021-01-05")
+        Seq(2, "a2", 20.0, 2000, "2021-01-06"),
+        Seq(3, "a3", 30.0, 3000, "2021-01-07")
       )
+    }
+  }
+
+  test("Test Insert Into with dynamic partition") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a partitioned table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           | id int,
+           | dt string,
+           | name string,
+           | price double,
+           | ts long
+           |) using hudi
+           | tblproperties (primaryKey = 'id')
+           | partitioned by (dt)
+           | location '${tmp.getCanonicalPath}'
+       """.stripMargin)
+
+      // Insert into dynamic partition
+      spark.sql(
+        s"""
+           | insert into $tableName partition(dt)
+           | select 1 as id, '2021-01-05' as dt, 'a1' as name, 10 as price, 1000 as ts
+        """.stripMargin)
+
+      spark.sql(
+        s"""
+           | insert into $tableName
+           | select 2 as id, 'a2' as name, 20 as price, 2000 as ts, '2021-01-06' as dt
+        """.stripMargin)
+
+      // Note: no field aliases are written here, so columns are matched by position and the partition field must be placed last.
+      spark.sql(
+        s"""
+           | insert into $tableName
+           | select 3, 'a3', 30, 3000, '2021-01-07'
+        """.stripMargin)
+
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "2021-01-05"),
+        Seq(2, "a2", 20.0, 2000, "2021-01-06"),
+        Seq(3, "a3", 30.0, 3000, "2021-01-07")
+      )
+    }
+  }
+
+  test("Test Insert Into with multi partition") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a partitioned table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           | id int,
+           | dt string,
+           | name string,
+           | price double,
+           | ht string,
+           | ts long
+           |) using hudi
+           | tblproperties (primaryKey = 'id')
+           | partitioned by (dt, ht)
+           | location '${tmp.getCanonicalPath}'
+       """.stripMargin)
+
+      spark.sql(
+        s"""
+           | insert into $tableName partition(dt, ht)
+           | select 1 as id, 'a1' as name, 10 as price, '20210101' as dt, 1000 as ts, '01' as ht
+        """.stripMargin)
+
+      // Insert into static partition and dynamic partition
+      spark.sql(
+        s"""
+           | insert into $tableName partition(dt = '20210102', ht)
+           | select 2 as id, 'a2' as name, 20 as price, 2000 as ts, '02' as ht
+        """.stripMargin)
+
+      spark.sql(
+        s"""
+           | insert into $tableName partition(dt, ht = '03')
+           | select 3 as id, 'a3' as name, 30 as price, 3000 as ts, '20210103' as dt
+        """.stripMargin)
+
+      // Note: no field aliases are written here, so columns are matched by position and the partition fields must be placed last.
+      spark.sql(
+        s"""
+           | insert into $tableName
+           | select 4, 'a4', 40, 4000, '20210104', '04'
+        """.stripMargin)
+
+      checkAnswer(s"select id, name, price, ts, dt, ht from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "20210101", "01"),
+        Seq(2, "a2", 20.0, 2000, "20210102", "02"),
+        Seq(3, "a3", 30.0, 3000, "20210103", "03"),
+        Seq(4, "a4", 40.0, 4000, "20210104", "04")
+      )
     }
   }
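Reviewer note: a minimal, self-contained sketch of the name-vs-position matching rule that getTableFieldsAlias introduces above. It assumes attribute names starting with "col" (e.g. col1 from a VALUES clause) are auto-generated; Field and alignByNameOrPosition are hypothetical stand-ins for illustration, not Hudi or Spark APIs.

// Sketch only: each query output column is matched to a target table field
// by name when the attribute carries a user-written alias, and by position
// otherwise. This mirrors the rule in getTableFieldsAlias, simplified.
object AlignmentSketch extends App {
  case class Field(name: String, dataType: String)

  def alignByNameOrPosition(queryNames: Seq[String],
      targetFields: Seq[Field]): Seq[(String, Field)] =
    queryNames.zip(targetFields).map { case (attrName, positionalField) =>
      // Names like "col1" are assumed auto-generated (e.g. by a VALUES
      // clause), so they cannot be matched by name; fall back to position.
      val byName =
        if (attrName.startsWith("col")) None
        else targetFields.find(_.name == attrName)
      attrName -> byName.getOrElse(positionalField)
    }

  val schema = Seq(Field("id", "int"), Field("name", "string"),
    Field("price", "double"), Field("ts", "long"))
  // Aliased output, reordered relative to the table schema: matched by name.
  println(alignByNameOrPosition(Seq("price", "ts", "id", "name"), schema))
  // Unaliased VALUES output (col1..col4): matched purely by position.
  println(alignByNameOrPosition(Seq("col1", "col2", "col3", "col4"), schema))
}

This is why the static-partition test above can reorder its aliased select columns, while the unaliased inserts must list columns in schema order with the partition fields last.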