From 8a3b61d2c4a857dac1d47d6a003ddf7510865d09 Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Mon, 3 Jun 2019 23:52:34 -0700 Subject: [PATCH 1/5] [SPARK-27845][SQL][WIP] DataSourceV2: InsertTable TODO: - DataFrameWriter.insertInto --- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../spark/sql/catalyst/dsl/package.scala | 13 +- .../sql/catalyst/parser/AstBuilder.scala | 28 ++-- .../logical/sql/InsertTableStatement.scala | 50 +++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 58 ++++++++ .../sql/catalyst/parser/PlanParserSuite.scala | 18 ++- .../datasources/DataSourceResolution.scala | 98 ++++++++++++- .../sql/sources/v2/DataSourceV2SQLSuite.scala | 111 ++++++++++++++- .../sources/v2/TestInMemoryTableCatalog.scala | 130 +++++++++++++----- 9 files changed, 443 insertions(+), 67 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0a142c29a16f..48739d289f7a 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -294,8 +294,8 @@ query ; insertInto - : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable - | INSERT INTO TABLE? tableIdentifier partitionSpec? #insertIntoTable + : INSERT OVERWRITE TABLE multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? #insertIntoTable | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? 
#insertOverwriteDir ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 54fc1f9abb08..fe539348e9ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql._ import org.apache.spark.sql.types._ /** @@ -379,10 +380,14 @@ package object dsl { Generate(generator, unrequiredChildIndex, outer, alias, outputNames.map(UnresolvedAttribute(_)), logicalPlan) - def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = - InsertIntoTable( - analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, ifPartitionNotExists = false) + def insertInto(tableName: String): LogicalPlan = insertInto(table(tableName)) + + def insertInto( + table: LogicalPlan, + overwrite: Boolean = false, + partition: Map[String, Option[String]] = Map.empty, + ifPartitionNotExists: Boolean = false): LogicalPlan = + InsertTableStatement(table, logicalPlan, overwrite, partition, ifPartitionNotExists) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a7a3b96ba726..90cfdce67930 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertTableStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -239,9 +239,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Parameters used for 
writing query to a table: - * (tableIdentifier, partitionKeys, exists). + * (multipartIdentifier, partitionKeys, ifPartitionNotExists). */ - type InsertTableParams = (TableIdentifier, Map[String, Option[String]], Boolean) + type InsertTableParams = (Seq[String], Map[String, Option[String]], Boolean) /** * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). @@ -263,11 +263,21 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging query: LogicalPlan): LogicalPlan = withOrigin(ctx) { ctx match { case table: InsertIntoTableContext => - val (tableIdent, partitionKeys, exists) = visitInsertIntoTable(table) - InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, false, exists) + val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertTableStatement( + UnresolvedRelation(tableIdent), + query, + overwrite = false, + partition, + ifPartitionNotExists) case table: InsertOverwriteTableContext => - val (tableIdent, partitionKeys, exists) = visitInsertOverwriteTable(table) - InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, true, exists) + val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertTableStatement( + UnresolvedRelation(tableIdent), + query, + overwrite = true, + partition, + ifPartitionNotExists) case dir: InsertOverwriteDirContext => val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) InsertIntoDir(isLocal, storage, provider, query, overwrite = true) @@ -284,7 +294,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitInsertIntoTable( ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) (tableIdent, partitionKeys, false) @@ -296,7 +306,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitInsertOverwriteTable( ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { assert(ctx.OVERWRITE() != null) - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala new file mode 100644 index 000000000000..b06d64c560dd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * An INSERT TABLE statement, as parsed from SQL. + * + * @param table the logical plan representing the table. + * @param query the logical plan representing data to write to. + * @param overwrite overwrite existing table or partitions. + * @param partition a map from the partition key to the partition value (optional). + * If the value is missing, dynamic partition insert will be performed. + * As an example, `INSERT INTO tbl PARTITION (a=1, b=2) AS` would have + * Map('a' -> Some('1'), 'b' -> Some('2')), + * and `INSERT INTO tbl PARTITION (a=1, b) AS ...` + * would have Map('a' -> Some('1'), 'b' -> None). + * @param ifPartitionNotExists If true, only write if the partition does not exist. + * Only valid for static partitions. + */ +case class InsertTableStatement( + table: LogicalPlan, + query: LogicalPlan, + overwrite: Boolean, + partition: Map[String, Option[String]], + ifPartitionNotExists: Boolean) extends ParsedStatement { + + // IF NOT EXISTS is only valid in INSERT OVERWRITE + assert(overwrite || !ifPartitionNotExists) + // IF NOT EXISTS is only valid in static partitions + assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists) + + override def children: Seq[LogicalPlan] = query :: Nil +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index dd84170e2620..4873e2ee403f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} @@ -616,6 +617,63 @@ class DDLParserSuite extends AnalysisTest { } } + test("insert table: append") { + parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl TABLE source", + table("source").insertInto(table("testcat", "ns1", "ns2", "tbl"))) + } + + 
test("insert table: append from another catalog") { + parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl TABLE testcat2.db.tbl", + table("testcat2", "db", "tbl").insertInto(table("testcat", "ns1", "ns2", "tbl"))) + } + + test("insert table: append with partition") { + parseCompare( + """ + |INSERT INTO testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) + |TABLE source + """.stripMargin, + table("source") + .insertInto( + table("testcat", "ns1", "ns2", "tbl"), + partition = Map("p1" -> Some("3"), "p2" -> None))) + } + + test("insert table: overwrite") { + parseCompare("INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl TABLE source", + table("source").insertInto(table("testcat", "ns1", "ns2", "tbl"), overwrite = true)) + } + + test("insert table: overwrite with partition") { + parseCompare( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) + |TABLE source + """.stripMargin, + table("source") + .insertInto( + table("testcat", "ns1", "ns2", "tbl"), + overwrite = true, + partition = Map("p1" -> Some("3"), "p2" -> None))) + } + + test("insert table: overwrite with partition if not exists") { + parseCompare( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3) IF NOT EXISTS + |TABLE source + """.stripMargin, + table("source") + .insertInto( + table("testcat", "ns1", "ns2", "tbl"), + overwrite = true, + partition = Map("p1" -> Some("3")), + ifPartitionNotExists = true)) + } + private case class TableSpec( name: Seq[String], schema: Option[StructType], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index fb245eef5e4b..4ac2480760fd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -190,7 +190,7 @@ class PlanParserSuite extends AnalysisTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifPartitionNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifPartitionNotExists) + plan.insertInto(table("s"), overwrite, partition, ifPartitionNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -205,10 +205,7 @@ class PlanParserSuite extends AnalysisTest { // Multi insert val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", - InsertIntoTable( - table("s"), Map.empty, plan.limit(1), false, ifPartitionNotExists = false).union( - InsertIntoTable( - table("u"), Map.empty, plan2, false, ifPartitionNotExists = false))) + plan.limit(1).insertInto("s").union(plan2.insertInto("u"))) } test ("insert with if not exists") { @@ -619,11 +616,12 @@ class PlanParserSuite extends AnalysisTest { comparePlans( parsePlan( "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), - InsertIntoTable(table("s"), Map.empty, - UnresolvedHint("REPARTITION", Seq(Literal(100)), - UnresolvedHint("COALESCE", Seq(Literal(500)), - UnresolvedHint("COALESCE", Seq(Literal(10)), - table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) + table("t") + .select(star()) + .hint("COALESCE", Literal(10)) + .hint("COALESCE", Literal(500)) + .hint("REPARTITION", Literal(100)) + .insertInto("s")) comparePlans( parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 8685d2f7a856..e06ac8461e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -25,14 +25,16 @@ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation} -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, ReplaceTable, ReplaceTableAsSelect} -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Cast, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceTable, ReplaceTableAsSelect} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertTableStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} @@ -42,6 +44,7 @@ case class DataSourceResolution( extends Rule[LogicalPlan] with CastSupport { import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ import lookup._ lazy val v2SessionCatalog: CatalogPlugin = lookup.sessionCatalog @@ -162,6 +165,95 @@ case class DataSourceResolution( case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) => UnresolvedCatalogRelation(catalogTable) + + case i @ InsertTableStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)), + _, _, _, _) if i.query.resolved => + 
loadTable(catalog, ident) + .map(DataSourceV2Relation.create) + .map(table => { + // ifPartitionNotExists is append with validation, but validation is not supported + if (i.ifPartitionNotExists) { + throw new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: ${table.table.name}") + } + + val staticPartitions = i.partition.filter(_._2.isDefined).mapValues(_.get) + + val resolver = conf.resolver + + // add any static value as a literal column + val staticPartitionProjectList = { + // check that the data column counts match + val numColumns = table.output.size + if (numColumns > staticPartitions.size + i.query.output.size) { + throw new AnalysisException(s"Cannot write: too many columns") + } else if (numColumns < staticPartitions.size + i.query.output.size) { + throw new AnalysisException(s"Cannot write: not enough columns") + } + + val staticNames = staticPartitions.keySet + + // for each static name, find the column name it will replace and check for unknowns. + val outputNameToStaticName = staticNames.map(staticName => + table.output.find(col => resolver(col.name, staticName)) match { + case Some(attr) => + attr.name -> staticName + case _ => + throw new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + }).toMap + + // for each output column, add the static value as a literal + // or use the next input column + val queryColumns = i.query.output.iterator + table.output.map { col => + outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { + case Some(staticValue) => + Alias(Cast(Literal(staticValue), col.dataType), col.name)() + case _ => + queryColumns.next + } + } + } + + val dynamicPartitionOverwrite = table.table.partitioning.size > 0 && + staticPartitions.size < table.table.partitioning.size && + conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + + val query = + if (staticPartitions.isEmpty) { + i.query + } else { + Project(staticPartitionProjectList, i.query) + } + + val deleteExpr = + if (staticPartitions.isEmpty) { + Literal(true) + } else { + staticPartitions.map { case (name, value) => + query.output.find(col => resolver(col.name, name)) match { + case Some(attr) => + EqualTo(attr, Cast(Literal(value), attr.dataType)) + case None => + throw new AnalysisException(s"Unknown static partition column: $name") + } + }.toSeq.reduce(And) + } + + if (!i.overwrite) { + AppendData.byPosition(table, query) + } else if (dynamicPartitionOverwrite) { + OverwritePartitionsDynamic.byPosition(table, query) + } else { + OverwriteByExpression.byPosition(table, query, deleteExpr) + } + }) + .getOrElse(i) + + case i @ InsertTableStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) + if i.query.resolved => + InsertIntoTable(i.table, i.partition, i.query, i.overwrite, i.ifPartitionNotExists) } object V1WriteProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index c173bdb95370..e274ca72ff73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog import 
org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 -import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, StringType, StructField, StructType, TimestampType} @@ -1349,4 +1349,113 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn assert(updated.properties == Map("provider" -> "foo").asJava) } } + + test("InsertTable: append") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") + sql(s"INSERT INTO $t1 SELECT id, data FROM source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + + test("InsertTable: append - across catalog") { + val t1 = "testcat.ns1.ns2.tbl" + val t2 = "testcat2.db.tbl" + withTable(t1, t2) { + sql(s"CREATE TABLE $t1 USING foo AS TABLE source") + sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo") + sql(s"INSERT INTO $t2 SELECT * FROM $t1") + checkAnswer(spark.table(t2), spark.table("source")) + } + } + + test("InsertTable: append partitioned table - dynamic clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO TABLE $t1 TABLE source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + + test("InsertTable: append partitioned table - static clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 PARTITION (id = 23) SELECT data FROM source") + checkAnswer(spark.table(t1), sql("SELECT 23, data FROM source")) + } + } + + test("InsertTable: overwrite non-partitioned table") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 USING foo AS TABLE source") + sql(s"INSERT OVERWRITE TABLE $t1 TABLE source2") + checkAnswer(spark.table(t1), spark.table("source2")) + } + } + + test("InsertTable: overwrite - dynamic clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 TABLE source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + } + + test("InsertTable: overwrite - dynamic clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) TABLE source") + checkAnswer(spark.table(t1), + spark.table("source").union(sql("SELECT 4L, 'keep'"))) + } + } + } + + test("InsertTable: overwrite - static clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) USING foo PARTITIONED BY (p1)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 4)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) TABLE source") + checkAnswer(spark.table(t1), + sql("SELECT id, data, 23 FROM source UNION 
SELECT 4L, 'keep', 4")) + } + } + + test("InsertTable: overwrite - mixed clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 4)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) TABLE source") + checkAnswer(spark.table(t1), + sql("SELECT id, data, 2 FROM source UNION SELECT 4L, 'keep', 4")) + } + } + } + + test("InsertTable: overwrite - mixed clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 4)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) TABLE source") + checkAnswer(spark.table(t1), + sql("SELECT id, data, 2 FROM source UNION SELECT 4L, 'keep', 4")) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 95398082b580..378cc17e449a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -24,12 +24,13 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, StagingTableCatalog, TableCatalog, TableChange} -import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalog.v2.expressions.{IdentityTransform, Transform} import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.{EqualTo, Filter} import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -70,12 +71,8 @@ class TestInMemoryTableCatalog extends TableCatalog { throw new TableAlreadyExistsException(ident) } TestInMemoryTableCatalog.maybeSimulateFailedTableCreation(properties) - if (partitions.nonEmpty) { - throw new UnsupportedOperationException( - s"Catalog $name: Partitioned tables are not supported") - } - val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties) + val table = new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties) tables.put(ident, table) @@ -93,7 +90,8 @@ class TestInMemoryTableCatalog extends TableCatalog { throw new IllegalArgumentException(s"Cannot drop all fields") } - val newTable = new InMemoryTable(table.name, schema, properties, table.data) + val newTable = new InMemoryTable(table.name, schema, 
table.partitioning, properties) + .withData(table.data) tables.put(ident, newTable) @@ -118,28 +116,43 @@ class TestInMemoryTableCatalog extends TableCatalog { class InMemoryTable( val name: String, val schema: StructType, + override val partitioning: Array[Transform], override val properties: util.Map[String, String]) extends Table with SupportsRead with SupportsWrite { - def this( - name: String, - schema: StructType, - properties: util.Map[String, String], - data: Array[BufferedRows]) = { - this(name, schema, properties) - replaceData(data) + partitioning.foreach { t => + if (!t.isInstanceOf[IdentityTransform]) { + throw new IllegalArgumentException(s"Transform $t must be IdentityTransform") + } } - def rows: Seq[InternalRow] = data.flatMap(_.rows) + @volatile var dataMap: mutable.Map[Seq[Any], BufferedRows] = mutable.Map.empty + + def data: Array[BufferedRows] = dataMap.values.toArray - @volatile var data: Array[BufferedRows] = Array.empty + def rows: Seq[InternalRow] = dataMap.values.flatMap(_.rows).toSeq - def replaceData(buffers: Array[BufferedRows]): Unit = synchronized { - data = buffers + private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) + private val partIndexes = partFieldNames.map(schema.fieldIndex(_)) + + private def getKey(row: InternalRow): Seq[Any] = partIndexes.map(row.toSeq(schema)(_)) + + def withData(data: Array[BufferedRows]): InMemoryTable = dataMap.synchronized { + data.foreach(_.rows.foreach { row => + val key = getKey(row) + dataMap += dataMap.get(key) + .map(key -> _.withRow(row)) + .getOrElse(key -> new BufferedRows().withRow(row)) + }) + this } override def capabilities: util.Set[TableCapability] = Set( - TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.OVERWRITE_BY_FILTER, + TableCapability.OVERWRITE_DYNAMIC, + TableCapability.TRUNCATE).asJava override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { () => new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) @@ -157,43 +170,79 @@ class InMemoryTable( override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { TestInMemoryTableCatalog.maybeSimulateFailedTableWrite(options) - new WriteBuilder with SupportsTruncate { - private var shouldTruncate: Boolean = false + + new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite { + private var writer: BatchWrite = Append override def truncate(): WriteBuilder = { - shouldTruncate = true + assert(writer == Append) + writer = TruncateAndAppend this } - override def buildForBatch(): BatchWrite = { - if (shouldTruncate) TruncateAndAppend else Append + override def overwrite(filters: Array[Filter]): WriteBuilder = { + assert(writer == Append) + writer = new Overwrite(filters) + this + } + + override def overwriteDynamicPartitions(): WriteBuilder = { + assert(writer == Append) + writer = DynamicOverwrite + this } + + override def buildForBatch(): BatchWrite = writer } } - private object TruncateAndAppend extends BatchWrite { + private abstract class TestBatchWrite extends BatchWrite { override def createBatchWriterFactory(): DataWriterFactory = { BufferedRowsWriterFactory } - override def commit(messages: Array[WriterCommitMessage]): Unit = { - replaceData(messages.map(_.asInstanceOf[BufferedRows])) + override def abort(messages: Array[WriterCommitMessage]): Unit = { } + } - override def abort(messages: 
Array[WriterCommitMessage]): Unit = { + private object Append extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + withData(messages.map(_.asInstanceOf[BufferedRows])) } } - private object Append extends BatchWrite { - override def createBatchWriterFactory(): DataWriterFactory = { - BufferedRowsWriterFactory + private object DynamicOverwrite extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val newData = messages.map(_.asInstanceOf[BufferedRows]) + dataMap --= newData.flatMap(_.rows.map(getKey)) + withData(newData) } + } - override def commit(messages: Array[WriterCommitMessage]): Unit = { - replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows])) + private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + val deleteKeys = dataMap.keys.filter { partValues => + filters.exists { + case EqualTo(attr, value) => + partFieldNames.zipWithIndex.find(_._1 == attr) match { + case Some((_, partIndex)) => + value == partValues(partIndex) + case _ => + throw new IllegalArgumentException(s"Unknown filter attribute: $attr") + } + case f @ _ => + throw new IllegalArgumentException(s"Unsupported filter type: $f") + } + } + dataMap --= deleteKeys + withData(messages.map(_.asInstanceOf[BufferedRows])) } + } - override def abort(messages: Array[WriterCommitMessage]): Unit = { + private object TruncateAndAppend extends TestBatchWrite { + override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { + dataMap = mutable.Map.empty + withData(messages.map(_.asInstanceOf[BufferedRows])) } } } @@ -231,7 +280,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedCreateTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } override def stageReplace( @@ -242,7 +291,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedReplaceTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } override def stageCreateOrReplace( @@ -253,7 +302,7 @@ class TestStagingInMemoryCatalog validateStagedTable(partitions, properties) new TestStagedCreateOrReplaceTable( ident, - new InMemoryTable(s"$name.${ident.quoted}", schema, properties)) + new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)) } private def validateStagedTable( @@ -335,6 +384,11 @@ class TestStagingInMemoryCatalog class BufferedRows extends WriterCommitMessage with InputPartition with Serializable { val rows = new mutable.ArrayBuffer[InternalRow]() + + def withRow(row: InternalRow): BufferedRows = { + rows.append(row) + this + } } private object BufferedRowsReaderFactory extends PartitionReaderFactory { From efc4bf72c8355710e1bb4ebe8283aa9760438370 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 12 Jul 2019 17:22:22 -0700 Subject: [PATCH 2/5] Update for review comments. 
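
For reference, the parser behavior after this revision is exercised by statements like the
following, drawn from the updated DDLParserSuite tests in this patch (the exact set of accepted
forms is defined by the SqlBase.g4 change below):

    INSERT INTO testcat.ns1.ns2.tbl SELECT * FROM source
    INSERT OVERWRITE testcat.ns1.ns2.tbl SELECT * FROM source
      -- the TABLE keyword is now optional for INSERT OVERWRITE
    INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl PARTITION (p1 = 3) IF NOT EXISTS SELECT * FROM source
    INSERT INTO TABLE testcat.ns1.ns2.tbl PARTITION (p1 = 3) IF NOT EXISTS SELECT * FROM source
      -- rejected: IF NOT EXISTS is only allowed with INSERT OVERWRITE and static partitions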
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../spark/sql/catalyst/dsl/package.scala | 4 +- .../sql/catalyst/parser/AstBuilder.scala | 18 +- ...tement.scala => InsertIntoStatement.scala} | 16 +- .../sql/catalyst/parser/DDLParserSuite.scala | 107 ++++++-- .../sql/catalyst/parser/PlanParserSuite.scala | 23 +- .../datasources/DataSourceResolution.scala | 183 +++++++------ .../sql/sources/v2/DataSourceV2SQLSuite.scala | 248 +++++++++++++++--- .../sources/v2/TestInMemoryTableCatalog.scala | 11 +- 9 files changed, 446 insertions(+), 168 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/{InsertTableStatement.scala => InsertIntoStatement.scala} (82%) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 48739d289f7a..517ef9de4964 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -294,8 +294,8 @@ query ; insertInto - : INSERT OVERWRITE TABLE multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable - | INSERT INTO TABLE? multipartIdentifier partitionSpec? #insertIntoTable + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? #insertIntoTable | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index fe539348e9ec..796043fff665 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -384,10 +384,10 @@ package object dsl { def insertInto( table: LogicalPlan, - overwrite: Boolean = false, partition: Map[String, Option[String]] = Map.empty, + overwrite: Boolean = false, ifPartitionNotExists: Boolean = false): LogicalPlan = - InsertTableStatement(table, logicalPlan, overwrite, partition, ifPartitionNotExists) + InsertIntoStatement(table, partition, logicalPlan, overwrite, ifPartitionNotExists) def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 90cfdce67930..c49e2f3d1513 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, 
AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertTableStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -264,19 +264,19 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging ctx match { case table: InsertIntoTableContext => val (tableIdent, partition, ifPartitionNotExists) = visitInsertIntoTable(table) - InsertTableStatement( + InsertIntoStatement( UnresolvedRelation(tableIdent), + partition, query, overwrite = false, - partition, ifPartitionNotExists) case table: InsertOverwriteTableContext => val (tableIdent, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) - InsertTableStatement( + InsertIntoStatement( UnresolvedRelation(tableIdent), + partition, query, overwrite = true, - partition, ifPartitionNotExists) case dir: InsertOverwriteDirContext => val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) @@ -297,6 +297,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx) + } + (tableIdent, partitionKeys, false) } @@ -311,8 +315,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { - throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. 
Specified " + - "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(","), ctx) } (tableIdent, partitionKeys, ctx.EXISTS() != null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala similarity index 82% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala index b06d64c560dd..c4210eabe26a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertTableStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/InsertIntoStatement.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.plans.logical.sql import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** - * An INSERT TABLE statement, as parsed from SQL. + * An INSERT INTO statement, as parsed from SQL. * * @param table the logical plan representing the table. * @param query the logical plan representing data to write to. * @param overwrite overwrite existing table or partitions. - * @param partition a map from the partition key to the partition value (optional). + * @param partitionSpec a map from the partition key to the partition value (optional). * If the value is missing, dynamic partition insert will be performed. * As an example, `INSERT INTO tbl PARTITION (a=1, b=2) AS` would have * Map('a' -> Some('1'), 'b' -> Some('2')), @@ -34,17 +34,17 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan * @param ifPartitionNotExists If true, only write if the partition does not exist. * Only valid for static partitions. 
*/ -case class InsertTableStatement( +case class InsertIntoStatement( table: LogicalPlan, + partitionSpec: Map[String, Option[String]], query: LogicalPlan, overwrite: Boolean, - partition: Map[String, Option[String]], ifPartitionNotExists: Boolean) extends ParsedStatement { - // IF NOT EXISTS is only valid in INSERT OVERWRITE - assert(overwrite || !ifPartitionNotExists) - // IF NOT EXISTS is only valid in static partitions - assert(partition.values.forall(_.nonEmpty) || !ifPartitionNotExists) + require(overwrite || !ifPartitionNotExists, + "IF NOT EXISTS is only valid in INSERT OVERWRITE") + require(partitionSpec.values.forall(_.nonEmpty) || !ifPartitionNotExists, + "IF NOT EXISTS is only valid with static partitions") override def children: Seq[LogicalPlan] = query :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 4873e2ee403f..0635f8e5e87e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.parser import java.util.Locale +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} -import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String @@ -617,14 +617,27 @@ class DDLParserSuite extends AnalysisTest { } } - test("insert table: append") { - parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl TABLE source", - table("source").insertInto(table("testcat", "ns1", "ns2", "tbl"))) + test("insert table: basic append") { + Seq( + "INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM source", + "INSERT INTO testcat.ns1.ns2.tbl SELECT * FROM source" 
+ ).foreach { sql => + parseCompare(sql, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = false, ifPartitionNotExists = false)) + } } test("insert table: append from another catalog") { - parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl TABLE testcat2.db.tbl", - table("testcat2", "db", "tbl").insertInto(table("testcat", "ns1", "ns2", "tbl"))) + parseCompare("INSERT INTO TABLE testcat.ns1.ns2.tbl SELECT * FROM testcat2.db.tbl", + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("testcat2", "db", "tbl"))), + overwrite = false, ifPartitionNotExists = false)) } test("insert table: append with partition") { @@ -632,17 +645,27 @@ class DDLParserSuite extends AnalysisTest { """ |INSERT INTO testcat.ns1.ns2.tbl |PARTITION (p1 = 3, p2) - |TABLE source + |SELECT * FROM source """.stripMargin, - table("source") - .insertInto( - table("testcat", "ns1", "ns2", "tbl"), - partition = Map("p1" -> Some("3"), "p2" -> None))) + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3"), "p2" -> None), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = false, ifPartitionNotExists = false)) } test("insert table: overwrite") { - parseCompare("INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl TABLE source", - table("source").insertInto(table("testcat", "ns1", "ns2", "tbl"), overwrite = true)) + Seq( + "INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl SELECT * FROM source", + "INSERT OVERWRITE testcat.ns1.ns2.tbl SELECT * FROM source" + ).foreach { sql => + parseCompare(sql, + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map.empty, + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = false)) + } } test("insert table: overwrite with partition") { @@ -650,13 +673,13 @@ class DDLParserSuite extends AnalysisTest { """ |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl |PARTITION (p1 = 3, p2) - |TABLE source + |SELECT * FROM source """.stripMargin, - table("source") - .insertInto( - table("testcat", "ns1", "ns2", "tbl"), - overwrite = true, - partition = Map("p1" -> Some("3"), "p2" -> None))) + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3"), "p2" -> None), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = false)) } test("insert table: overwrite with partition if not exists") { @@ -664,14 +687,40 @@ class DDLParserSuite extends AnalysisTest { """ |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl |PARTITION (p1 = 3) IF NOT EXISTS - |TABLE source + |SELECT * FROM source """.stripMargin, - table("source") - .insertInto( - table("testcat", "ns1", "ns2", "tbl"), - overwrite = true, - partition = Map("p1" -> Some("3")), - ifPartitionNotExists = true)) + InsertIntoStatement( + UnresolvedRelation(Seq("testcat", "ns1", "ns2", "tbl")), + Map("p1" -> Some("3")), + Project(Seq(UnresolvedStar(None)), UnresolvedRelation(Seq("source"))), + overwrite = true, ifPartitionNotExists = true)) + } + + test("insert table: if not exists with dynamic partition fails") { + val exc = intercept[AnalysisException] { + parsePlan( + """ + |INSERT OVERWRITE TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3, p2) IF NOT EXISTS + |SELECT * FROM 
source + """.stripMargin) + } + + assert(exc.getMessage.contains("IF NOT EXISTS with dynamic partitions")) + assert(exc.getMessage.contains("p2")) + } + + test("insert table: if not exists without overwrite fails") { + val exc = intercept[AnalysisException] { + parsePlan( + """ + |INSERT INTO TABLE testcat.ns1.ns2.tbl + |PARTITION (p1 = 3) IF NOT EXISTS + |SELECT * FROM source + """.stripMargin) + } + + assert(exc.getMessage.contains("INSERT INTO ... IF NOT EXISTS")) } private case class TableSpec( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 4ac2480760fd..61f8c3b99149 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql.InsertIntoStatement import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.IntegerType @@ -184,13 +185,15 @@ class PlanParserSuite extends AnalysisTest { } test("insert into") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ val sql = "select * from t" val plan = table("t").select(star()) def insert( partition: Map[String, Option[String]], overwrite: Boolean = false, ifPartitionNotExists: Boolean = false): LogicalPlan = - plan.insertInto(table("s"), overwrite, partition, ifPartitionNotExists) + InsertIntoStatement(table("s"), partition, plan, overwrite, ifPartitionNotExists) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -208,13 +211,6 @@ class PlanParserSuite extends AnalysisTest { plan.limit(1).insertInto("s").union(plan2.insertInto("u"))) } - test ("insert with if not exists") { - val sql = "select * from t" - intercept(s"insert overwrite table s partition (e = 1, x) if not exists $sql", - "Dynamic partitions do not support IF NOT EXISTS. 
Specified partitions with value: [x]") - intercept[ParseException](parsePlan(s"insert overwrite table s if not exists $sql")) - } - test("aggregation") { val sql = "select a, b, sum(c) as c from d group by a, b" @@ -616,12 +612,11 @@ class PlanParserSuite extends AnalysisTest { comparePlans( parsePlan( "INSERT INTO s SELECT /*+ REPARTITION(100), COALESCE(500), COALESCE(10) */ * FROM t"), - table("t") - .select(star()) - .hint("COALESCE", Literal(10)) - .hint("COALESCE", Literal(500)) - .hint("REPARTITION", Literal(100)) - .insertInto("s")) + InsertIntoStatement(table("s"), Map.empty, + UnresolvedHint("REPARTITION", Seq(Literal(100)), + UnresolvedHint("COALESCE", Seq(Literal(500)), + UnresolvedHint("COALESCE", Seq(Literal(10)), + table("t").select(star())))), overwrite = false, ifPartitionNotExists = false)) comparePlans( parsePlan("SELECT /*+ BROADCASTJOIN(u), REPARTITION(100) */ * FROM t"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index e06ac8461e52..7293d5af2664 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -23,19 +23,19 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} -import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Cast, EqualTo, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Cast, EqualTo, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceTable, ReplaceTableAsSelect} -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertTableStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, 
AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode -import org.apache.spark.sql.sources.v2.TableProvider +import org.apache.spark.sql.sources.v2.{Table, TableProvider} import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} case class DataSourceResolution( @@ -166,94 +166,39 @@ case class DataSourceResolution( case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) => UnresolvedCatalogRelation(catalogTable) - case i @ InsertTableStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)), + case i @ InsertIntoStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)), _, _, _, _) if i.query.resolved => loadTable(catalog, ident) .map(DataSourceV2Relation.create) - .map(table => { + .map(relation => { // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${table.table.name}") + s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") } - val staticPartitions = i.partition.filter(_._2.isDefined).mapValues(_.get) - - val resolver = conf.resolver - - // add any static value as a literal column - val staticPartitionProjectList = { - // check that the data column counts match - val numColumns = table.output.size - if (numColumns > staticPartitions.size + i.query.output.size) { - throw new AnalysisException(s"Cannot write: too many columns") - } else if (numColumns < staticPartitions.size + i.query.output.size) { - throw new AnalysisException(s"Cannot write: not enough columns") - } - - val staticNames = staticPartitions.keySet - - // for each static name, find the column name it will replace and check for unknowns. 
- val outputNameToStaticName = staticNames.map(staticName => - table.output.find(col => resolver(col.name, staticName)) match { - case Some(attr) => - attr.name -> staticName - case _ => - throw new AnalysisException( - s"Cannot add static value for unknown column: $staticName") - }).toMap - - // for each output column, add the static value as a literal - // or use the next input column - val queryColumns = i.query.output.iterator - table.output.map { col => - outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { - case Some(staticValue) => - Alias(Cast(Literal(staticValue), col.dataType), col.name)() - case _ => - queryColumns.next - } - } - } + val partCols = partitionColumnNames(relation.table) + validatePartitionSpec(partCols, i.partitionSpec) - val dynamicPartitionOverwrite = table.table.partitioning.size > 0 && - staticPartitions.size < table.table.partitioning.size && + val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) + val query = addStaticPartitionColumns(relation, i.query, staticPartitions) + val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC - val query = - if (staticPartitions.isEmpty) { - i.query - } else { - Project(staticPartitionProjectList, i.query) - } - - val deleteExpr = - if (staticPartitions.isEmpty) { - Literal(true) - } else { - staticPartitions.map { case (name, value) => - query.output.find(col => resolver(col.name, name)) match { - case Some(attr) => - EqualTo(attr, Cast(Literal(value), attr.dataType)) - case None => - throw new AnalysisException(s"Unknown static partition column: $name") - } - }.toSeq.reduce(And) - } - if (!i.overwrite) { - AppendData.byPosition(table, query) + AppendData.byPosition(relation, query) } else if (dynamicPartitionOverwrite) { - OverwritePartitionsDynamic.byPosition(table, query) + OverwritePartitionsDynamic.byPosition(relation, query) } else { - OverwriteByExpression.byPosition(table, query, deleteExpr) + OverwriteByExpression.byPosition( + relation, query, staticDeleteExpression(relation, staticPartitions)) } }) .getOrElse(i) - case i @ InsertTableStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) + case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) if i.query.resolved => - InsertIntoTable(i.table, i.partition, i.query, i.overwrite, i.ifPartitionNotExists) + InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists) } object V1WriteProvider { @@ -442,4 +387,94 @@ case class DataSourceResolution( nullable = true, builder.build()) } + + private def partitionColumnNames(table: Table): Seq[String] = { + // get partition column names. in v2, partition columns are columns that are stored using an + // identity partition transform because the partition values and the column values are + // identical. otherwise, partition values are produced by transforming one or more source + // columns and cannot be set directly in a query's PARTITION clause. + table.partitioning.flatMap { + case IdentityTransform(FieldReference(Seq(name))) => Some(name) + case _ => None + } + } + + private def validatePartitionSpec( + partitionColumnNames: Seq[String], + partitionSpec: Map[String, Option[String]]): Unit = { + // check that each partition name is a partition column. 
otherwise, it is not valid + partitionSpec.keySet.foreach { partitionName => + partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { + case Some(_) => + case None => + throw new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + } + } + + private def addStaticPartitionColumns( + relation: DataSourceV2Relation, + query: LogicalPlan, + staticPartitions: Map[String, String]): LogicalPlan = { + + if (staticPartitions.isEmpty) { + query + + } else { + // add any static value as a literal column + val withStaticPartitionValues = { + // for each static name, find the column name it will replace and check for unknowns. + val outputNameToStaticName = staticPartitions.keySet.map(staticName => + relation.output.find(col => conf.resolver(col.name, staticName)) match { + case Some(attr) => + attr.name -> staticName + case _ => + throw new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + }).toMap + + val queryColumns = query.output.iterator + + // for each output column, add the static value as a literal, or use the next input + // column. this does not fail if input columns are exhausted and adds remaining columns + // at the end. both cases will be caught by ResolveOutputRelation and will fail the + // query with a helpful error message. + relation.output.flatMap { col => + outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { + case Some(staticValue) => + Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)()) + case _ if queryColumns.hasNext => + Some(queryColumns.next) + case _ => + None + } + } ++ queryColumns + } + + Project(withStaticPartitionValues, query) + } + } + + private def staticDeleteExpression( + relation: DataSourceV2Relation, + staticPartitions: Map[String, String]): Expression = { + if (staticPartitions.isEmpty) { + Literal(true) + } else { + staticPartitions.map { case (name, value) => + relation.output.find(col => conf.resolver(col.name, name)) match { + case Some(attr) => + // the delete expression must reference the table's column names, but these attributes + // are not available when CheckAnalysis runs because the relation is not a child of the + // logical operation. instead, expressions are resolved after ResolveOutputRelation + // runs, using the query's column names that will match the table names at that point. + // because resolution happens after a future rule, create an UnresolvedAttribute. 
+ EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) + case None => + throw new AnalysisException(s"Unknown static partition column: $name") + } + }.reduce(And) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index e274ca72ff73..681c774beba0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.scalatest.BeforeAndAfter import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog @@ -1350,7 +1350,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn } } - test("InsertTable: append") { + test("InsertInto: append") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") @@ -1359,7 +1359,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn } } - test("InsertTable: append - across catalog") { + test("InsertInto: append - across catalog") { val t1 = "testcat.ns1.ns2.tbl" val t2 = "testcat2.db.tbl" withTable(t1, t2) { @@ -1370,16 +1370,81 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn } } - test("InsertTable: append partitioned table - dynamic clause") { + test("InsertInto: append to partitioned table - without PARTITION clause") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") - sql(s"INSERT INTO TABLE $t1 TABLE source") + sql(s"INSERT INTO TABLE $t1 SELECT * FROM source") checkAnswer(spark.table(t1), spark.table("source")) } } - test("InsertTable: append partitioned table - static clause") { + test("InsertInto: append to partitioned table - with PARTITION clause") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), spark.table("source")) + } + } + + test("InsertInto: dynamic PARTITION clause fails with non-partition column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (data) SELECT * FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains("PARTITION clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("data")) + } + } + + test("InsertInto: static PARTITION clause fails with non-partition column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (data)") + + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO TABLE $t1 PARTITION (id=1) SELECT data FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains("PARTITION 
clause cannot contain a non-partition column name")) + assert(exc.getMessage.contains("id")) + } + } + + test("InsertInto: fails when missing a column") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, missing string) USING foo") + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO $t1 SELECT id, data FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains(s"Cannot write to '$t1', not enough data columns")) + } + } + + test("InsertInto: fails when an extra column is present") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo") + val exc = intercept[AnalysisException] { + sql(s"INSERT INTO $t1 SELECT id, data, 'fruit' FROM source") + } + + assert(spark.table(t1).count === 0) + assert(exc.getMessage.contains(s"Cannot write to '$t1', too many data columns")) + } + } + + test("InsertInto: append to partitioned table - static clause") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") @@ -1388,73 +1453,196 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn } } - test("InsertTable: overwrite non-partitioned table") { + test("InsertInto: overwrite non-partitioned table") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 USING foo AS TABLE source") - sql(s"INSERT OVERWRITE TABLE $t1 TABLE source2") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source2") checkAnswer(spark.table(t1), spark.table("source2")) } } - test("InsertTable: overwrite - dynamic clause - static mode") { + test("InsertInto: overwrite - dynamic clause - static mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"))) + } + } + } + + test("InsertInto: overwrite - dynamic clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") - sql(s"INSERT OVERWRITE TABLE $t1 TABLE source") - checkAnswer(spark.table(t1), spark.table("source")) + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "keep"))) } } } - test("InsertTable: overwrite - dynamic clause - dynamic mode") { + test("InsertInto: overwrite - missing clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'also-deleted')") + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"))) + } + } + } + + test("InsertInto: overwrite - missing clause - dynamic mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> 
PartitionOverwriteMode.DYNAMIC.toString) { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo PARTITIONED BY (id)") sql(s"INSERT INTO $t1 VALUES (2L, 'dummy'), (4L, 'keep')") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id) TABLE source") - checkAnswer(spark.table(t1), - spark.table("source").union(sql("SELECT 4L, 'keep'"))) + sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a"), + Row(2, "b"), + Row(3, "c"), + Row(4, "keep"))) } } } - test("InsertTable: overwrite - static clause") { + test("InsertInto: overwrite - static clause") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string, p1 int) USING foo PARTITIONED BY (p1)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 4)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) TABLE source") - checkAnswer(spark.table(t1), - sql("SELECT id, data, 23 FROM source UNION SELECT 4L, 'keep', 4")) + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 23), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p1 = 23) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 23), + Row(2, "b", 23), + Row(3, "c", 23), + Row(4, "keep", 2))) + } + } + + test("InsertInto: overwrite - mixed clause - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } + } + } + + test("InsertInto: overwrite - mixed clause reordered - static mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } } } - test("InsertTable: overwrite - mixed clause - static mode") { + test("InsertInto: overwrite - implicit dynamic partition - static mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 4)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) TABLE source") - checkAnswer(spark.table(t1), - sql("SELECT id, data, 2 FROM source UNION SELECT 4L, 'keep', 4")) + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'also-deleted', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2))) + } + } + } + + test("InsertInto: overwrite - mixed clause - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + 
sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) + } + } + } + + test("InsertInto: overwrite - mixed clause reordered - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id, p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) + } + } + } + + test("InsertInto: overwrite - implicit dynamic partition - dynamic mode") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2) SELECT * FROM source") + checkAnswer(spark.table(t1), Seq( + Row(1, "a", 2), + Row(2, "b", 2), + Row(3, "c", 2), + Row(4, "keep", 2))) } } } - test("InsertTable: overwrite - mixed clause - dynamic mode") { + test("InsertInto: overwrite - multiple static partitions - dynamic mode") { withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { sql(s"CREATE TABLE $t1 (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)") - sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 4)") - sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (p = 2, id) TABLE source") - checkAnswer(spark.table(t1), - sql("SELECT id, data, 2 FROM source UNION SELECT 4L, 'keep', 4")) + sql(s"INSERT INTO $t1 VALUES (2L, 'dummy', 2), (4L, 'keep', 2)") + sql(s"INSERT OVERWRITE TABLE $t1 PARTITION (id = 2, p = 2) SELECT data FROM source") + checkAnswer(spark.table(t1), Seq( + Row(2, "a", 2), + Row(2, "b", 2), + Row(2, "c", 2), + Row(4, "keep", 2))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 378cc17e449a..8aa6523ac4d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalog.v2.expressions.{IdentityTransform, Transform import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException} -import org.apache.spark.sql.sources.{EqualTo, Filter} +import org.apache.spark.sql.sources.{And, EqualTo, Filter} import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.types.StructType @@ -222,7 +222,7 @@ 
class InMemoryTable( private class Overwrite(filters: Array[Filter]) extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val deleteKeys = dataMap.keys.filter { partValues => - filters.exists { + filters.flatMap(splitAnd).forall { case EqualTo(attr, value) => partFieldNames.zipWithIndex.find(_._1 == attr) match { case Some((_, partIndex)) => @@ -237,6 +237,13 @@ class InMemoryTable( dataMap --= deleteKeys withData(messages.map(_.asInstanceOf[BufferedRows])) } + + private def splitAnd(filter: Filter): Seq[Filter] = { + filter match { + case And(left, right) => splitAnd(left) ++ splitAnd(right) + case _ => filter :: Nil + } + } } private object TruncateAndAppend extends TestBatchWrite { From 8449d6689e763b7e33c1d3d83429f24d55705def Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Tue, 16 Jul 2019 17:33:41 -0700 Subject: [PATCH 3/5] Update error message to fix failing test. --- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 2 +- .../src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c49e2f3d1513..7d1ff153aef2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -316,7 +316,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + - dynamicPartitionKeys.keys.mkString(","), ctx) + dynamicPartitionKeys.keys.mkString(", "), ctx) } (tableIdent, partitionKeys, ctx.EXISTS() != null) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index 70307ed7e830..73f5bbd88624 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -201,8 +201,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter |SELECT 7, 8, 3 """.stripMargin) } - assert(e.getMessage.contains( - "Dynamic partitions do not support IF NOT EXISTS. Specified partitions with value: [c]")) + assert(e.getMessage.contains("IF NOT EXISTS with dynamic partitions: c")) // If the partition already exists, the insert will overwrite the data // unless users specify IF NOT EXISTS From 97dc04c32c230709e1e25fafbb9cc77f0913ef92 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 17 Jul 2019 13:34:02 -0700 Subject: [PATCH 4/5] Move InsertInto rules into Analyzer to fix error messages. 
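
Moving the InsertIntoStatement conversion out of DataSourceResolution and into a
new ResolveInsertInto rule in the Analyzer's main resolution batch lets the v2
insert plans resolve alongside ResolveTables, ResolveRelations and
ResolveReferences. The validation done during conversion (rejecting IF NOT
EXISTS, rejecting non-partition columns in the PARTITION clause, and unknown
static partition columns) then surfaces as the expected AnalysisException
messages instead of being reported later. Inserts that resolve to a session
catalog TableIdentifier still fall back to the v1 InsertIntoTable node.

The write-plan selection itself is unchanged by the move. As a minimal,
self-contained sketch of that dispatch (stand-in names and types for
illustration, not Spark's API):

    // Which v2 write plan an INSERT resolves to, given the overwrite flag, the
    // table's identity-transform partition columns, the static PARTITION values,
    // and whether partitionOverwriteMode is DYNAMIC.
    sealed trait V2Write
    case object Append extends V2Write
    case object DynamicPartitionOverwrite extends V2Write
    case class StaticOverwrite(staticPartitions: Map[String, String]) extends V2Write

    def chooseWrite(
        overwrite: Boolean,
        partitionColumns: Seq[String],
        staticPartitions: Map[String, String],
        dynamicMode: Boolean): V2Write = {
      // dynamic overwrite only applies when at least one partition column is left dynamic
      val dynamicPartitionOverwrite =
        partitionColumns.size > staticPartitions.size && dynamicMode
      if (!overwrite) Append
      else if (dynamicPartitionOverwrite) DynamicPartitionOverwrite
      else StaticOverwrite(staticPartitions)
    }

In the real rule, Append corresponds to AppendData.byPosition,
DynamicPartitionOverwrite to OverwritePartitionsDynamic.byPosition, and
StaticOverwrite to OverwriteByExpression.byPosition with the delete expression
built from the static partition values.
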
--- .../sql/catalyst/analysis/Analyzer.scala | 137 +++++++++++++++++- .../datasources/DataSourceResolution.scala | 123 ---------------- 2 files changed, 136 insertions(+), 124 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e55cdfedd323..021fb26bf751 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -25,6 +25,8 @@ import scala.util.Random import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange} +import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform} +import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -34,12 +36,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, InsertIntoStatement} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode +import org.apache.spark.sql.sources.v2.Table import org.apache.spark.sql.types._ /** @@ -167,6 +171,7 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: ResolveAlterTable :: + ResolveInsertInto :: ResolveTables :: ResolveRelations :: ResolveReferences :: @@ -757,6 +762,136 @@ class Analyzer( } } + object ResolveInsertInto extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case i @ InsertIntoStatement( + UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _) + if i.query.resolved => + loadTable(tableCatalog, ident) + .map(DataSourceV2Relation.create) + .map(relation => { + // ifPartitionNotExists is append with validation, but validation is not supported + if (i.ifPartitionNotExists) { + throw new AnalysisException( + s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") + } + + val partCols = partitionColumnNames(relation.table) + validatePartitionSpec(partCols, i.partitionSpec) + + val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) + val query = addStaticPartitionColumns(relation, i.query, staticPartitions) + val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && + 
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + + if (!i.overwrite) { + AppendData.byPosition(relation, query) + } else if (dynamicPartitionOverwrite) { + OverwritePartitionsDynamic.byPosition(relation, query) + } else { + OverwriteByExpression.byPosition( + relation, query, staticDeleteExpression(relation, staticPartitions)) + } + }) + .getOrElse(i) + + case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) + if i.query.resolved => + InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists) + } + + private def partitionColumnNames(table: Table): Seq[String] = { + // get partition column names. in v2, partition columns are columns that are stored using an + // identity partition transform because the partition values and the column values are + // identical. otherwise, partition values are produced by transforming one or more source + // columns and cannot be set directly in a query's PARTITION clause. + table.partitioning.flatMap { + case IdentityTransform(FieldReference(Seq(name))) => Some(name) + case _ => None + } + } + + private def validatePartitionSpec( + partitionColumnNames: Seq[String], + partitionSpec: Map[String, Option[String]]): Unit = { + // check that each partition name is a partition column. otherwise, it is not valid + partitionSpec.keySet.foreach { partitionName => + partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { + case Some(_) => + case None => + throw new AnalysisException( + s"PARTITION clause cannot contain a non-partition column name: $partitionName") + } + } + } + + private def addStaticPartitionColumns( + relation: DataSourceV2Relation, + query: LogicalPlan, + staticPartitions: Map[String, String]): LogicalPlan = { + + if (staticPartitions.isEmpty) { + query + + } else { + // add any static value as a literal column + val withStaticPartitionValues = { + // for each static name, find the column name it will replace and check for unknowns. + val outputNameToStaticName = staticPartitions.keySet.map(staticName => + relation.output.find(col => conf.resolver(col.name, staticName)) match { + case Some(attr) => + attr.name -> staticName + case _ => + throw new AnalysisException( + s"Cannot add static value for unknown column: $staticName") + }).toMap + + val queryColumns = query.output.iterator + + // for each output column, add the static value as a literal, or use the next input + // column. this does not fail if input columns are exhausted and adds remaining columns + // at the end. both cases will be caught by ResolveOutputRelation and will fail the + // query with a helpful error message. 
+ relation.output.flatMap { col => + outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { + case Some(staticValue) => + Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)()) + case _ if queryColumns.hasNext => + Some(queryColumns.next) + case _ => + None + } + } ++ queryColumns + } + + Project(withStaticPartitionValues, query) + } + } + + private def staticDeleteExpression( + relation: DataSourceV2Relation, + staticPartitions: Map[String, String]): Expression = { + if (staticPartitions.isEmpty) { + Literal(true) + } else { + staticPartitions.map { case (name, value) => + relation.output.find(col => conf.resolver(col.name, name)) match { + case Some(attr) => + // the delete expression must reference the table's column names, but these attributes + // are not available when CheckAnalysis runs because the relation is not a child of + // the logical operation. instead, expressions are resolved after + // ResolveOutputRelation runs, using the query's column names that will match the + // table names at that point. because resolution happens after a future rule, create + // an UnresolvedAttribute. + EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) + case None => + throw new AnalysisException(s"Unknown static partition column: $name") + } + }.reduce(And) + } + } + } + /** * Resolve ALTER TABLE statements that use a DSv2 catalog. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 7293d5af2664..7c99d24dbdfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -166,39 +166,6 @@ case class DataSourceResolution( case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) => UnresolvedCatalogRelation(catalogTable) - case i @ InsertIntoStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)), - _, _, _, _) if i.query.resolved => - loadTable(catalog, ident) - .map(DataSourceV2Relation.create) - .map(relation => { - // ifPartitionNotExists is append with validation, but validation is not supported - if (i.ifPartitionNotExists) { - throw new AnalysisException( - s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}") - } - - val partCols = partitionColumnNames(relation.table) - validatePartitionSpec(partCols, i.partitionSpec) - - val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get) - val query = addStaticPartitionColumns(relation, i.query, staticPartitions) - val dynamicPartitionOverwrite = partCols.size > staticPartitions.size && - conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC - - if (!i.overwrite) { - AppendData.byPosition(relation, query) - } else if (dynamicPartitionOverwrite) { - OverwritePartitionsDynamic.byPosition(relation, query) - } else { - OverwriteByExpression.byPosition( - relation, query, staticDeleteExpression(relation, staticPartitions)) - } - }) - .getOrElse(i) - - case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _) - if i.query.resolved => - InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists) } object V1WriteProvider { @@ -387,94 +354,4 @@ case class DataSourceResolution( nullable = true, builder.build()) } - - private def partitionColumnNames(table: Table): 
Seq[String] = { - // get partition column names. in v2, partition columns are columns that are stored using an - // identity partition transform because the partition values and the column values are - // identical. otherwise, partition values are produced by transforming one or more source - // columns and cannot be set directly in a query's PARTITION clause. - table.partitioning.flatMap { - case IdentityTransform(FieldReference(Seq(name))) => Some(name) - case _ => None - } - } - - private def validatePartitionSpec( - partitionColumnNames: Seq[String], - partitionSpec: Map[String, Option[String]]): Unit = { - // check that each partition name is a partition column. otherwise, it is not valid - partitionSpec.keySet.foreach { partitionName => - partitionColumnNames.find(name => conf.resolver(name, partitionName)) match { - case Some(_) => - case None => - throw new AnalysisException( - s"PARTITION clause cannot contain a non-partition column name: $partitionName") - } - } - } - - private def addStaticPartitionColumns( - relation: DataSourceV2Relation, - query: LogicalPlan, - staticPartitions: Map[String, String]): LogicalPlan = { - - if (staticPartitions.isEmpty) { - query - - } else { - // add any static value as a literal column - val withStaticPartitionValues = { - // for each static name, find the column name it will replace and check for unknowns. - val outputNameToStaticName = staticPartitions.keySet.map(staticName => - relation.output.find(col => conf.resolver(col.name, staticName)) match { - case Some(attr) => - attr.name -> staticName - case _ => - throw new AnalysisException( - s"Cannot add static value for unknown column: $staticName") - }).toMap - - val queryColumns = query.output.iterator - - // for each output column, add the static value as a literal, or use the next input - // column. this does not fail if input columns are exhausted and adds remaining columns - // at the end. both cases will be caught by ResolveOutputRelation and will fail the - // query with a helpful error message. - relation.output.flatMap { col => - outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match { - case Some(staticValue) => - Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)()) - case _ if queryColumns.hasNext => - Some(queryColumns.next) - case _ => - None - } - } ++ queryColumns - } - - Project(withStaticPartitionValues, query) - } - } - - private def staticDeleteExpression( - relation: DataSourceV2Relation, - staticPartitions: Map[String, String]): Expression = { - if (staticPartitions.isEmpty) { - Literal(true) - } else { - staticPartitions.map { case (name, value) => - relation.output.find(col => conf.resolver(col.name, name)) match { - case Some(attr) => - // the delete expression must reference the table's column names, but these attributes - // are not available when CheckAnalysis runs because the relation is not a child of the - // logical operation. instead, expressions are resolved after ResolveOutputRelation - // runs, using the query's column names that will match the table names at that point. - // because resolution happens after a future rule, create an UnresolvedAttribute. - EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType)) - case None => - throw new AnalysisException(s"Unknown static partition column: $name") - } - }.reduce(And) - } - } } From 7f193ca074b7fd0dea9f7b28b4e1776819e6d512 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Thu, 25 Jul 2019 09:05:01 -0700 Subject: [PATCH 5/5] Address more review comments. 
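
Follow-ups from review: DataSourceResolution drops the imports and the
CatalogV2Util wildcard that were only needed by the insert handling that moved
to the Analyzer; the two remaining CTAS statements in DataSourceV2SQLSuite use
"AS SELECT * FROM source" instead of "AS TABLE source"; the redundant
"case f @ _" binding becomes "case f"; and the in-memory test table truncates
by clearing dataMap in place rather than swapping in a new map.

Illustrative only (plain Scala, not the test catalog's fields): clearing in
place empties the map that existing references already point at, instead of
publishing a fresh instance.

    import scala.collection.mutable

    val dataMap = mutable.Map("partition-0" -> Seq(1, 2, 3))
    val alias = dataMap          // another reference to the same map
    dataMap.clear()              // truncate in place
    assert(alias.isEmpty)        // the truncation is visible through every reference
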
--- .../datasources/DataSourceResolution.scala | 13 +++++-------- .../spark/sql/sources/v2/DataSourceV2SQLSuite.scala | 4 ++-- .../sql/sources/v2/TestInMemoryTableCatalog.scala | 4 ++-- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 7c99d24dbdfd..a51678da2d8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -23,19 +23,17 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} -import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform, Transform} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{CastSupport, UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, UnresolvedCatalogRelation} -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Cast, EqualTo, Expression, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project, ReplaceTable, ReplaceTableAsSelect} -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, InsertIntoStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, ReplaceTable, ReplaceTableAsSelect} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode -import org.apache.spark.sql.sources.v2.{Table, TableProvider} +import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} case class DataSourceResolution( @@ -44,7 +42,6 @@ case class DataSourceResolution( extends Rule[LogicalPlan] with CastSupport { import 
org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ - import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util._ import lookup._ lazy val v2SessionCatalog: CatalogPlugin = lookup.sessionCatalog diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index 681c774beba0..a3e029f53cec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -1363,7 +1363,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn val t1 = "testcat.ns1.ns2.tbl" val t2 = "testcat2.db.tbl" withTable(t1, t2) { - sql(s"CREATE TABLE $t1 USING foo AS TABLE source") + sql(s"CREATE TABLE $t1 USING foo AS SELECT * FROM source") sql(s"CREATE TABLE $t2 (id bigint, data string) USING foo") sql(s"INSERT INTO $t2 SELECT * FROM $t1") checkAnswer(spark.table(t2), spark.table("source")) @@ -1456,7 +1456,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn test("InsertInto: overwrite non-partitioned table") { val t1 = "testcat.ns1.ns2.tbl" withTable(t1) { - sql(s"CREATE TABLE $t1 USING foo AS TABLE source") + sql(s"CREATE TABLE $t1 USING foo AS SELECT * FROM source") sql(s"INSERT OVERWRITE TABLE $t1 SELECT * FROM source2") checkAnswer(spark.table(t1), spark.table("source2")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 8aa6523ac4d8..19a41bee19d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -230,7 +230,7 @@ class InMemoryTable( case _ => throw new IllegalArgumentException(s"Unknown filter attribute: $attr") } - case f @ _ => + case f => throw new IllegalArgumentException(s"Unsupported filter type: $f") } } @@ -248,7 +248,7 @@ class InMemoryTable( private object TruncateAndAppend extends TestBatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { - dataMap = mutable.Map.empty + dataMap.clear withData(messages.map(_.asInstanceOf[BufferedRows])) } }