From ae50eec75f212d1840f441fabe3572c2094bdd5a Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 18 Oct 2019 23:51:35 -0700
Subject: [PATCH] TRUNCATE TABLE should do multi-catalog resolution.

---
 .../spark/sql/catalyst/parser/SqlBase.g4      |  2 +-
 .../sql/catalyst/parser/AstBuilder.scala      | 14 ++++++++++++++
 .../catalyst/plans/logical/statements.scala   |  7 +++++++
 .../sql/catalyst/parser/DDLParserSuite.scala  | 10 ++++++++++
 .../analysis/ResolveSessionCatalog.scala      |  8 +++++++-
 .../spark/sql/execution/SparkSqlParser.scala  | 14 --------------
 .../sql/connector/DataSourceV2SQLSuite.scala  | 22 ++++++++++++++++++++
 7 files changed, 61 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 1839203e3b23..4c93f1fe1197 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -211,7 +211,7 @@ statement
     | CLEAR CACHE #clearCache
     | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE
         tableIdentifier partitionSpec? #loadData
-    | TRUNCATE TABLE tableIdentifier partitionSpec? #truncateTable
+    | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable
     | MSCK REPAIR TABLE multipartIdentifier #repairTable
     | op=(ADD | LIST) identifier .*? #manageResource
     | SET ROLE .*? #failNativeCommand
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 8af7cf9ad800..862903246ed3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2728,4 +2728,18 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
   override def visitRepairTable(ctx: RepairTableContext): LogicalPlan = withOrigin(ctx) {
     RepairTableStatement(visitMultipartIdentifier(ctx.multipartIdentifier()))
   }
+
+  /**
+   * Create a [[TruncateTableStatement]] command.
+   *
+   * For example:
+   * {{{
+   *   TRUNCATE TABLE multi_part_name [PARTITION (partcol1=val1, partcol2=val2 ...)]
+   * }}}
+   */
+  override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
+    TruncateTableStatement(
+      visitMultipartIdentifier(ctx.multipartIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
+  }
 }
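Note: a quick sketch of what the new visitor yields, driven through the existing
CatalystSqlParser entry point (illustration only, not part of the patch; the
expected values mirror the DDLParserSuite cases added below):

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
    import org.apache.spark.sql.catalyst.plans.logical.TruncateTableStatement

    // The parser keeps the multi-part name intact; which catalog owns the table
    // is decided later during analysis, not at parse time.
    val plan = CatalystSqlParser.parsePlan(
      "TRUNCATE TABLE testcat.ns.tbl PARTITION (ds='2017-06-10')")
    assert(plan == TruncateTableStatement(
      Seq("testcat", "ns", "tbl"), Some(Map("ds" -> "2017-06-10"))))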
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
index 72d5cbb7d904..1a69a6ab3380 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statements.scala
@@ -316,3 +316,10 @@ case class AnalyzeColumnStatement(
  * A REPAIR TABLE statement, as parsed from SQL
  */
 case class RepairTableStatement(tableName: Seq[String]) extends ParsedStatement
+
+/**
+ * A TRUNCATE TABLE statement, as parsed from SQL
+ */
+case class TruncateTableStatement(
+    tableName: Seq[String],
+    partitionSpec: Option[TablePartitionSpec]) extends ParsedStatement
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
index 0eaf74f65506..0d87d0ce9b0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala
@@ -961,6 +961,16 @@ class DDLParserSuite extends AnalysisTest {
       RepairTableStatement(Seq("a", "b", "c")))
   }
 
+  test("TRUNCATE table") {
+    comparePlans(
+      parsePlan("TRUNCATE TABLE a.b.c"),
+      TruncateTableStatement(Seq("a", "b", "c"), None))
+
+    comparePlans(
+      parsePlan("TRUNCATE TABLE a.b.c PARTITION(ds='2017-06-10')"),
+      TruncateTableStatement(Seq("a", "b", "c"), Some(Map("ds" -> "2017-06-10"))))
+  }
+
   private case class TableSpec(
       name: Seq[String],
       schema: Option[StructType],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index 72f539f72008..978214778a4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, LookupCatalog, TableChange, V1Table}
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand}
+import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableRecoverPartitionsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, AnalyzeColumnCommand, AnalyzePartitionCommand, AnalyzeTableCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand, ShowTablesCommand, TruncateTableCommand}
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource}
 import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
@@ -282,6 +282,12 @@ class ResolveSessionCatalog(
       AlterTableRecoverPartitionsCommand(
         v1TableName.asTableIdentifier,
         "MSCK REPAIR TABLE")
+
+    case TruncateTableStatement(tableName, partitionSpec) =>
+      val v1TableName = parseV1Table(tableName, "TRUNCATE TABLE")
+      TruncateTableCommand(
+        v1TableName.asTableIdentifier,
+        partitionSpec)
   }
 
   private def parseV1Table(tableName: Seq[String], sql: String): Seq[String] = {
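A note on the new case arm: parseV1Table only lets a name through when it
resolves to the session catalog, so the conversion to the v1 command is roughly
the following (a simplified, self-contained sketch for illustration; the real
rule resolves the leading name parts through LookupCatalog and throws Spark's
AnalysisException, for which a plain RuntimeException stands in here):

    import org.apache.spark.sql.catalyst.TableIdentifier

    // Sketch of the v1 fallback: short names become a session-catalog
    // TableIdentifier; anything else is treated as a v2 name and rejected.
    def toV1TruncateTarget(tableName: Seq[String]): TableIdentifier = tableName match {
      case Seq(tbl)     => TableIdentifier(tbl)
      case Seq(db, tbl) => TableIdentifier(tbl, Some(db))
      case _ => throw new RuntimeException(
        "TRUNCATE TABLE is only supported with v1 tables.")
    }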
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 3e7a54877cae..235420785aec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -346,20 +346,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
     )
   }
 
-  /**
-   * Create a [[TruncateTableCommand]] command.
-   *
-   * For example:
-   * {{{
-   *   TRUNCATE TABLE tablename [PARTITION (partcol1=val1, partcol2=val2 ...)]
-   * }}}
-   */
-  override def visitTruncateTable(ctx: TruncateTableContext): LogicalPlan = withOrigin(ctx) {
-    TruncateTableCommand(
-      visitTableIdentifier(ctx.tableIdentifier),
-      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))
-  }
-
   /**
    * Create a [[CreateDatabaseCommand]] command.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index d253e6078ddc..01c051f15635 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1210,6 +1210,28 @@ class DataSourceV2SQLSuite
     }
   }
 
+  test("TRUNCATE TABLE") {
+    val t = "testcat.ns1.ns2.tbl"
+    withTable(t) {
+      sql(
+        s"""
+           |CREATE TABLE $t (id bigint, data string)
+           |USING foo
+           |PARTITIONED BY (id)
+        """.stripMargin)
+
+      val e1 = intercept[AnalysisException] {
+        sql(s"TRUNCATE TABLE $t")
+      }
+      assert(e1.message.contains("TRUNCATE TABLE is only supported with v1 tables"))
+
+      val e2 = intercept[AnalysisException] {
+        sql(s"TRUNCATE TABLE $t PARTITION(id='1')")
+      }
+      assert(e2.message.contains("TRUNCATE TABLE is only supported with v1 tables"))
+    }
+  }
+
   private def assertAnalysisError(sqlStatement: String, expectedError: String): Unit = {
     val errMsg = intercept[AnalysisException] {
       sql(sqlStatement)
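For reviewers trying the change out, the user-visible effect sketched from
spark-shell (assumes a SparkSession `spark` with a v2 catalog registered under
the name `testcat` via the spark.sql.catalog.testcat conf, matching this
suite's setup):

    // v1 path: session-catalog tables keep truncating through TruncateTableCommand.
    spark.sql("CREATE TABLE v1tbl (id INT) USING parquet")
    spark.sql("TRUNCATE TABLE v1tbl")

    // v2 path: a multi-part name now resolves against `testcat` during analysis,
    // and TRUNCATE TABLE reports a clear v1-only error instead of a parse failure.
    spark.sql("TRUNCATE TABLE testcat.ns1.ns2.tbl")
    // org.apache.spark.sql.AnalysisException:
    //   TRUNCATE TABLE is only supported with v1 tables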