diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala index 75a5b10d9ed0..64f57835c889 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Command.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.expressions.Attribute + /** * A logical node that represents a non-query command to be executed by the system. For example, * commands can be used by parsers to represent DDL operations. Commands, unlike queries, are * eagerly executed. */ -trait Command +trait Command extends LeafNode { + final override def children: Seq[LogicalPlan] = Seq.empty + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 6df47acaba85..ff1bb126f463 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -31,10 +31,7 @@ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.IntegerType /** A dummy command for testing unsupported operations. */ -case class DummyCommand() extends LogicalPlan with Command { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil -} +case class DummyCommand() extends Command class UnsupportedOperationsSuite extends SparkFunSuite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e7faab549542..ccfb95454b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -449,7 +449,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { c.bucketSpec, c.mode, c.options, - c.child) + c.query) ExecutedCommandExec(cmd) :: Nil case c: CreateTempViewUsing => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index b0e2d03af070..af6def52d07d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -129,6 +129,4 @@ case object ResetCommand extends RunnableCommand with Logging { sparkSession.sessionState.conf.clear() Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 697e2ff21159..605e49c5a87c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -47,8 +47,6 @@ case class CacheTableCommand( Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } @@ -58,8 +56,6 @@ case class UncacheTableCommand(tableIdent: TableIdentifier) extends RunnableComm sparkSession.catalog.uncacheTable(tableIdent.quotedString) Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } /** @@ -71,6 +67,4 @@ case object ClearCacheCommand extends RunnableCommand { sparkSession.catalog.clearCache() Seq.empty[Row] } - - override def output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 424a962b5eb1..698c625d617f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -35,9 +35,7 @@ import org.apache.spark.sql.types._ * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends LogicalPlan with logical.Command { - override def output: Seq[Attribute] = Seq.empty - final override def children: Seq[LogicalPlan] = Seq.empty +trait RunnableCommand extends logical.Command { def run(sparkSession: SparkSession): Seq[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala index 597ec27ce669..e5a6a5f60b8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/databases.scala @@ -59,6 +59,4 @@ case class SetDatabaseCommand(databaseName: String) extends RunnableCommand { sparkSession.sessionState.catalog.setCurrentDatabase(databaseName) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 16deee359f53..d63d29defdd6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -72,8 +72,6 @@ case class CreateDatabaseCommand( ifNotExists) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } @@ -103,8 +101,6 @@ case class DropDatabaseCommand( sparkSession.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** @@ -128,8 +124,6 @@ case class AlterDatabasePropertiesCommand( Seq.empty[Row] } - - override val output: Seq[Attribute] = Seq.empty } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9253db0f2082..ad0c779ff29a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, UnaryNode} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.catalyst.util.quoteIdentifier import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.datasources.PartitioningUtils @@ -43,10 +44,10 @@ import org.apache.spark.util.Utils case class CreateHiveTableAsSelectLogicalPlan( tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean) extends UnaryNode with Command { + query: LogicalPlan, + allowExisting: Boolean) extends Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) override lazy val resolved: Boolean = tableDesc.identifier.database.isDefined && @@ -54,7 +55,7 @@ case class CreateHiveTableAsSelectLogicalPlan( tableDesc.storage.serde.isDefined && tableDesc.storage.inputFormat.isDefined && tableDesc.storage.outputFormat.isDefined && - childrenResolved + query.resolved } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 31a2075d2ff9..857047f820ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand @@ -41,17 +41,10 @@ case class CreateTableUsing( partitionColumns: Array[String], bucketSpec: Option[BucketSpec], allowExisting: Boolean, - managedIfNoPath: Boolean) extends LogicalPlan with logical.Command { - - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq.empty -} + managedIfNoPath: Boolean) extends logical.Command /** * A node used to support CTAS statements and saveAsTable for the data source API. - * This node is a [[logical.UnaryNode]] instead of a [[logical.Command]] because we want the - * analyzer can analyze the logical plan that will be used to populate the table. - * So, [[PreWriteCheck]] can detect cases that are not allowed. */ case class CreateTableUsingAsSelect( tableIdent: TableIdentifier, @@ -60,8 +53,10 @@ case class CreateTableUsingAsSelect( bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], - child: LogicalPlan) extends logical.UnaryNode { - override def output: Seq[Attribute] = Seq.empty[Attribute] + query: LogicalPlan) extends logical.Command { + + override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override lazy val resolved: Boolean = query.resolved } case class CreateTempViewUsing( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 05908d908fd2..27420d58e56a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrd import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.CreateHiveTableAsSelectLogicalPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} @@ -61,6 +62,25 @@ class ResolveDataSource(sparkSession: SparkSession) extends Rule[LogicalPlan] { } } +/** + * Analyze the query in CREATE TABLE AS SELECT (CTAS). After analysis, [[PreWriteCheck]] also + * can detect the cases that are not allowed. + */ +case class AnalyzeCreateTableAsSelect(sparkSession: SparkSession) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case c: CreateTableUsingAsSelect if !c.query.resolved => + c.copy(query = analyzeQuery(c.query)) + case c: CreateHiveTableAsSelectLogicalPlan if !c.query.resolved => + c.copy(query = analyzeQuery(c.query)) + } + + private def analyzeQuery(query: LogicalPlan): LogicalPlan = { + val qe = sparkSession.sessionState.executePlan(query) + qe.assertAnalyzed() + qe.analyzed + } +} + /** * Preprocess the [[InsertIntoTable]] plan. Throws exception if the number of columns mismatch, or * specified partition columns are different from the existing partition columns in the target @@ -216,7 +236,7 @@ case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _, _) => // Get all input data source relations of the query. - val srcRelations = c.child.collect { + val srcRelations = c.query.collect { case LogicalRelation(src: BaseRelation, _, _) => src } if (srcRelations.contains(dest)) { @@ -233,12 +253,12 @@ case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) } PartitioningUtils.validatePartitionColumn( - c.child.schema, c.partitionColumns, conf.caseSensitiveAnalysis) + c.query.schema, c.partitionColumns, conf.caseSensitiveAnalysis) for { spec <- c.bucketSpec sortColumnName <- spec.sortColumnNames - sortColumn <- c.child.schema.find(_.name == sortColumnName) + sortColumn <- c.query.schema.find(_.name == sortColumnName) } { if (!RowOrdering.isOrderable(sortColumn.dataType)) { failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 01cc13f9df88..e054ef2bcf75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.AnalyzeTableCommand -import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreprocessTableInsertion, ResolveDataSource} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.streaming.{StreamingQuery, StreamingQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -111,6 +111,7 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = + AnalyzeCreateTableAsSelect(sparkSession) :: PreprocessTableInsertion(conf) :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis(conf) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index da5c538eace2..7ab0fe07b9c4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -27,9 +27,11 @@ import scala.util.Random import org.scalatest.Matchers._ import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1565,4 +1567,22 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val d = sampleDf.withColumn("c", monotonically_increasing_id).select($"c").collect assert(d.size == d.distinct.size) } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Data source tables) More Than Once") { + withTable("bar") { + withTempView("foo") { + withSQLConf(SQLConf.DEFAULT_DATA_SOURCE_NAME.key -> "json") { + sql("select 0 as id").createOrReplaceTempView("foo") + val df = sql("select * from foo group by id") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + df.write.mode("overwrite").saveAsTable("bar") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(tableMetadata.properties(DATASOURCE_PROVIDER) == "json", + "the expected table is a data source table using json") + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala index 251a25665a42..ea1f7a5a3c34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala @@ -220,4 +220,16 @@ class CreateTableAsSelectSuite Some(BucketSpec(5, Seq("a"), Seq("b")))) } } + + test("SPARK-17409: CTAS of decimal calculation") { + withTable("tab2") { + withTempView("tab1") { + spark.range(99, 101).createOrReplaceTempView("tab1") + val sqlStmt = + "SELECT id, cast(id as long) * cast('1.0' as decimal(38, 18)) as num FROM tab1" + sql(s"CREATE TABLE tab2 USING PARQUET AS $sqlStmt") + checkAnswer(spark.table("tab2"), sql(sqlStmt)) + } + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index bafb42277e33..e7d1ed34f5ab 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -449,7 +449,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log case p: LogicalPlan if !p.childrenResolved => p case p: LogicalPlan if p.resolved => p - case p @ CreateHiveTableAsSelectLogicalPlan(table, child, allowExisting) => + case p @ CreateHiveTableAsSelectLogicalPlan(table, query, allowExisting) => val desc = if (table.storage.serde.isEmpty) { // add default serde table.withNewStorage( @@ -462,7 +462,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log execution.CreateHiveTableAsSelectCommand( desc.copy(identifier = TableIdentifier(tblName, Some(dbName))), - child, + query, allowExisting) } } @@ -510,7 +510,7 @@ private[hive] case class InsertIntoHiveTable( child: LogicalPlan, overwrite: Boolean, ifNotExists: Boolean) - extends LogicalPlan with Command { + extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 8773993d362c..822a7709ba16 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -62,6 +62,7 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) override lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = + AnalyzeCreateTableAsSelect(sparkSession) :: catalog.ParquetConversions :: catalog.OrcConversions :: catalog.CreateTables :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala index eec60b440720..c6711c323625 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreRelationSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.hive -import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.execution.command.DDLUtils +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils -class MetastoreRelationSuite extends SparkFunSuite { +class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("makeCopy and toJSON should work") { val table = CatalogTable( identifier = TableIdentifier("test", Some("db")), @@ -35,4 +38,19 @@ class MetastoreRelationSuite extends SparkFunSuite { // No exception should be thrown relation.toJSON } + + test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") { + withTable("bar") { + withTempView("foo") { + sql("select 0 as id").createOrReplaceTempView("foo") + // If we optimize the query in CTAS more than once, the following saveAsTable will fail + // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])` + sql("CREATE TABLE bar AS SELECT * FROM foo group by id") + checkAnswer(spark.table("bar"), Row(0) :: Nil) + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar")) + assert(!DDLUtils.isDatasourceTable(tableMetadata), + "the expected table is a Hive serde table") + } + } + } }