HoodieSqlUtils.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.hudi
import scala.collection.JavaConverters._
import java.net.URI
import java.util.Locale

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hudi.SparkAdapterSupport
@@ -30,7 +29,7 @@ import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
-import org.apache.spark.sql.catalyst.expressions.{And, Cast, Expression, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Cast, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.SQLConf
@@ -106,6 +105,10 @@ object HoodieSqlUtils extends SparkAdapterSupport {
}
}

def removeMetaFields(attrs: Seq[Attribute]): Seq[Attribute] = {
attrs.filterNot(attr => isMetaField(attr.name))
}

/**
* Get the table location.
* @param tableId
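The new removeMetaFields overload strips Hudi's internal meta columns from an attribute list, so that only user-facing columns are counted or compared downstream. It leans on the isMetaField helper already defined in HoodieSqlUtils; a minimal sketch of that check, assuming it simply matches the attribute name against Hudi's five standard meta columns (the real helper presumably consults HoodieRecord.HOODIE_META_COLUMNS):

// Sketch only: metaFieldNames is a stand-in for the set the real helper
// consults; the five names below are Hudi's standard meta columns.
private val metaFieldNames = Set(
  "_hoodie_commit_time", "_hoodie_commit_seqno", "_hoodie_record_key",
  "_hoodie_partition_path", "_hoodie_file_name")

def isMetaField(name: String): Boolean = metaFieldNames.contains(name)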
InsertIntoHoodieTableCommand.scala
@@ -128,6 +128,11 @@ object InsertIntoHoodieTableCommand {
s"Required partition columns is: ${targetPartitionSchema.json}, Current static partitions " +
s"is: ${staticPartitionValues.mkString("," + "")}")

assert(staticPartitionValues.size + query.output.size == table.schema.size,
s"Required select columns count: ${removeMetaFields(table.schema).size}, " +
s"Current select columns(including static partition column) count: " +
s"${staticPartitionValues.size + removeMetaFields(query.output).size},columns: " +
s"(${(removeMetaFields(query.output).map(_.name) ++ staticPartitionValues.keys).mkString(",")})")
val queryDataFields = if (staticPartitionValues.isEmpty) { // insert dynamic partition
query.output.dropRight(targetPartitionSchema.fields.length)
} else { // insert static partition
@@ -156,7 +161,7 @@
targetPartitionSchema.fields.map(f => {
val staticPartitionValue = staticPartitionValues.getOrElse(f.name,
s"Missing static partition value for: ${f.name}")
-val castAttr = Literal.create(staticPartitionValue, f.dataType)
+val castAttr = castIfNeeded(Literal.create(staticPartitionValue), f.dataType, conf)
Alias(castAttr, f.name)()
})
}
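Before this change the static partition value, which is still a plain string at this point, was forced directly into a typed Literal via Literal.create(value, f.dataType); that fails at runtime for non-string partition columns such as date, int, or timestamp. The fix builds a string literal and lets castIfNeeded (defined in HoodieSqlUtils) convert it to the partition field's type. A plausible sketch of that helper, assuming it is a no-op when the type already matches and otherwise wraps the expression in a Cast with the session time zone:

import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.DataType

// Assumed contract: leave matching types untouched, otherwise insert an
// explicit Cast that honours the session-local time zone.
def castIfNeeded(child: Expression, dataType: DataType, conf: SQLConf): Expression = {
  if (child.dataType == dataType) child
  else Cast(child, dataType, Option(conf.sessionLocalTimeZone))
}

This is what the new "Test Different Type of Partition Column" case below exercises for string, int, bigint, timestamp, and date partition columns.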
TestHoodieSqlBase.scala
@@ -78,4 +78,13 @@ class TestHoodieSqlBase extends FunSuite with BeforeAndAfterAll {
protected def checkAnswer(sql: String)(expects: Seq[Any]*): Unit = {
assertResult(expects.map(row => Row(row: _*)).toArray)(spark.sql(sql).collect())
}

protected def checkException(sql: String)(errorMsg: String): Unit = {
try {
spark.sql(sql)
} catch {
case e: Throwable =>
assertResult(errorMsg)(e.getMessage)
}
}
}
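One caveat with checkException as written: if spark.sql(sql) succeeds, the catch block never runs and the test passes vacuously. A stricter variant (a sketch, not part of this patch) records whether an exception was actually seen:

protected def checkExceptionStrict(sql: String)(errorMsg: String): Unit = {
  var seen = false
  try {
    spark.sql(sql)
  } catch {
    case e: Throwable =>
      assertResult(errorMsg)(e.getMessage)
      seen = true
  }
  // Fail when the statement unexpectedly succeeded.
  assert(seen, s"Expected the statement to fail: $sql")
}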
TestInsertTable.scala
@@ -220,4 +220,68 @@ class TestInsertTable extends TestHoodieSqlBase {
)
}
}

test("Test Different Type of Partition Column") {
withTempDir { tmp =>
val typeAndValue = Seq(
("string", "'1000'"),
("int", 1000),
("bigint", 10000),
("timestamp", "'2021-05-20 00:00:00'"),
("date", "'2021-05-20'")
)
typeAndValue.foreach { case (partitionType, partitionValue) =>
val tableName = generateTableName
// Create table
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| price double,
| dt $partitionType
|) using hudi
| partitioned by (dt)
| location '${tmp.getCanonicalPath}/$tableName'
""".stripMargin)

spark.sql(s"insert into $tableName partition(dt = $partitionValue) select 1, 'a1', 10")
spark.sql(s"insert into $tableName select 2, 'a2', 10, $partitionValue")
checkAnswer(s"select id, name, price, cast(dt as string) from $tableName order by id")(
Seq(1, "a1", 10, removeQuotes(partitionValue).toString),
Seq(2, "a2", 10, removeQuotes(partitionValue).toString)
)
}
}
}

test("Test Insert Exception") {
val tableName = generateTableName
spark.sql(
s"""
|create table $tableName (
| id int,
| name string,
| price double,
| dt string
|) using hudi
| partitioned by (dt)
""".stripMargin)
checkException(s"insert into $tableName partition(dt = '2021-06-20')" +
s" select 1, 'a1', 10, '2021-06-20'") (
"assertion failed: Required select columns count: 4, Current select columns(including static partition column)" +
" count: 5,columns: (1,a1,10,2021-06-20,dt)"
)
checkException(s"insert into $tableName select 1, 'a1', 10")(
"assertion failed: Required select columns count: 4, Current select columns(including static partition column)" +
" count: 3,columns: (1,a1,10)"
)
}

private def removeQuotes(value: Any): Any = {
value match {
case s: String => s.stripPrefix("'").stripSuffix("'")
case _ => value
}
}
}