diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 65a38aacecd43..624c25d95c704 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -384,6 +384,11 @@ trait CheckAnalysis extends PredicateHelper {
           failAnalysis(s"Invalid partitioning: ${badReferences.mkString(", ")}")
         }
 
+        create.tableSchema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))
+
+      case write: V2WriteCommand if write.resolved =>
+        write.query.schema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))
+
       // If the view output doesn't have the same number of columns neither with the child
       // output, nor with the query column names, throw an AnalysisException.
       // If the view's child output can't up cast to the view output,
@@ -443,23 +448,27 @@ trait CheckAnalysis extends PredicateHelper {
             if (parent.nonEmpty) {
               findField("add to", parent)
             }
+            TypeUtils.failWithIntervalType(add.dataType())
 
           case update: UpdateColumnType =>
             val field = findField("update", update.fieldNames)
             val fieldName = update.fieldNames.quoted
             update.newDataType match {
               case _: StructType =>
-                throw new AnalysisException(
-                  s"Cannot update ${table.name} field $fieldName type: " +
-                    s"update a struct by adding, deleting, or updating its fields")
+                alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                  s"update a struct by updating its fields")
               case _: MapType =>
-                throw new AnalysisException(
-                  s"Cannot update ${table.name} field $fieldName type: " +
-                    s"update a map by updating $fieldName.key or $fieldName.value")
+                alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                  s"update a map by updating $fieldName.key or $fieldName.value")
              case _: ArrayType =>
-                throw new AnalysisException(
-                  s"Cannot update ${table.name} field $fieldName type: " +
-                    s"update the element by updating $fieldName.element")
-              case _: AtomicType =>
+                alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                  s"update the element by updating $fieldName.element")
+              case u: UserDefinedType[_] =>
+                alter.failAnalysis(s"Cannot update ${table.name} field $fieldName type: " +
+                  s"update a UserDefinedType[${u.sql}] by updating its fields")
+              case _: CalendarIntervalType =>
+                alter.failAnalysis(s"Cannot update ${table.name} field $fieldName to " +
+                  s"interval type")
+              case _ => // update is okay
             }
             if (!Cast.canUpCast(field.dataType, update.newDataType)) {
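Taken together, the two new CheckAnalysis cases reject intervals both when a v2 table's schema is declared explicitly (CREATE/REPLACE TABLE, where the case also validates partitioning) and when it is derived from a query (CTAS/RTAS via create.tableSchema, plus resolved v2 writes via write.query.schema). A minimal sketch of the user-facing behavior, assuming a local session; the demo object name, the "foo" provider string, and the catalog wiring are illustrative, and InMemoryTableCatalog is a test-only class from sql/core's test sources:

import org.apache.spark.sql.{AnalysisException, SparkSession}

object IntervalSchemaCheckDemo extends App {
  // Illustrative setup: any v2 TableCatalog implementation would do here.
  val spark = SparkSession.builder()
    .master("local[1]")
    .config("spark.sql.catalog.testcat",
      "org.apache.spark.sql.connector.InMemoryTableCatalog")
    .getOrCreate()

  // CreateV2Table: the explicit schema is scanned field by field.
  try spark.sql("CREATE TABLE testcat.t (i interval) USING foo")
  catch { case e: AnalysisException => println(e.getMessage) }
  // -> Cannot use interval type in the table schema.

  // CTAS: the schema comes from the query, so the same check fires.
  try spark.sql("CREATE TABLE testcat.t2 USING foo AS SELECT interval 1 day AS i")
  catch { case e: AnalysisException => println(e.getMessage) }
  // -> Cannot use interval type in the table schema.
}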
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 9680ea3cd2067..e8266dd401362 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.RowOrdering
 import org.apache.spark.sql.types._
@@ -98,4 +99,18 @@ object TypeUtils {
     case _: AtomicType => true
     case _ => false
   }
+
+  def failWithIntervalType(dataType: DataType): Unit = {
+    dataType match {
+      case CalendarIntervalType =>
+        throw new AnalysisException("Cannot use interval type in the table schema.")
+      case ArrayType(et, _) => failWithIntervalType(et)
+      case MapType(kt, vt, _) =>
+        failWithIntervalType(kt)
+        failWithIntervalType(vt)
+      case s: StructType => s.foreach(f => failWithIntervalType(f.dataType))
+      case u: UserDefinedType[_] => failWithIntervalType(u.sqlType)
+      case _ =>
+    }
+  }
 }
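failWithIntervalType recurses through array element types, map key and value types, struct fields, and the underlying sql type of UDTs, so an interval is rejected no matter how deeply it is nested. A small sketch of the helper's behavior; the demo object is illustrative, and TypeUtils lives in Spark's internal catalyst package, not a stable public API:

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

object NestedIntervalDemo extends App {
  // An interval hidden inside a map value: the recursive walk still finds it.
  val schema = new StructType()
    .add("id", IntegerType)
    .add("m", MapType(StringType, CalendarIntervalType))

  try schema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))
  catch { case e: AnalysisException => println(e.getMessage) }
  // -> Cannot use interval type in the table schema.

  // A schema without intervals passes silently (Unit, no exception).
  TypeUtils.failWithIntervalType(ArrayType(StringType))
}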
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
index d304d5b2ca6a2..8a20ca8a4c187 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/AlterTableTests.scala
@@ -116,6 +116,19 @@ trait AlterTableTests extends SharedSparkSession {
     }
   }
 
+  test("AlterTable: add column with interval type") {
+    val t = s"${catalogAndNamespace}table_name"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id int, point struct<x: double, y: double>) USING $v2Format")
+      val e1 =
+        intercept[AnalysisException](sql(s"ALTER TABLE $t ADD COLUMN data interval"))
+      assert(e1.getMessage.contains("Cannot use interval type in the table schema."))
+      val e2 =
+        intercept[AnalysisException](sql(s"ALTER TABLE $t ADD COLUMN point.z interval"))
+      assert(e2.getMessage.contains("Cannot use interval type in the table schema."))
+    }
+  }
+
   test("AlterTable: add column with position") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
@@ -310,6 +323,15 @@ trait AlterTableTests extends SharedSparkSession {
     }
   }
 
+  test("AlterTable: update column type to interval") {
+    val t = s"${catalogAndNamespace}table_name"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id int) USING $v2Format")
+      val e = intercept[AnalysisException](sql(s"ALTER TABLE $t ALTER COLUMN id TYPE interval"))
+      assert(e.getMessage.contains("id to interval type"))
+    }
+  }
+
   test("AlterTable: SET/DROP NOT NULL") {
     val t = s"${catalogAndNamespace}table_name"
     withTable(t) {
@@ -358,7 +380,7 @@ trait AlterTableTests extends SharedSparkSession {
     }
 
     assert(exc.getMessage.contains("point"))
-    assert(exc.getMessage.contains("update a struct by adding, deleting, or updating its fields"))
+    assert(exc.getMessage.contains("update a struct by updating its fields"))
 
     val table = getTableMetadata(t)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
index 72d4629a1a320..0a6897b829994 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala
@@ -17,14 +17,19 @@
 
 package org.apache.spark.sql.connector
 
-import org.apache.spark.sql.{DataFrame, Row, SaveMode}
+import java.util.Collections
+
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan}
+import org.apache.spark.sql.connector.catalog.Identifier
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.QueryExecutionListener
 
 class DataSourceV2DataFrameSuite
   extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = false) {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
   import testImplicits._
 
   before {
@@ -163,4 +168,22 @@ class DataSourceV2DataFrameSuite
       spark.listenerManager.unregister(listener)
     }
   }
+
+  test("Cannot write data with intervals to v2") {
+    withTable("testcat.table_name") {
+      val testCatalog = spark.sessionState.catalogManager.catalog("testcat").asTableCatalog
+      testCatalog.createTable(
+        Identifier.of(Array(), "table_name"),
+        new StructType().add("i", "interval"),
+        Array.empty, Collections.emptyMap[String, String])
+      val df = sql("select interval 1 day as i")
+      val v2Writer = df.writeTo("testcat.table_name")
+      val e1 = intercept[AnalysisException](v2Writer.append())
+      assert(e1.getMessage.contains("Cannot use interval type in the table schema."))
+      val e2 = intercept[AnalysisException](v2Writer.overwrite(df("i")))
+      assert(e2.getMessage.contains("Cannot use interval type in the table schema."))
+      val e3 = intercept[AnalysisException](v2Writer.overwritePartitions())
+      assert(e3.getMessage.contains("Cannot use interval type in the table schema."))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 4c5b1d95b12da..2b08e86fea637 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -252,6 +252,28 @@ class DataSourceV2SQLSuite
     checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty)
   }
 
+  test("CreateTable/ReplaceTable: invalid schema if it has interval type") {
+    Seq("CREATE", "REPLACE").foreach { action =>
+      val e1 = intercept[AnalysisException](
+        sql(s"$action TABLE table_name (id int, value interval) USING $v2Format"))
+      assert(e1.getMessage.contains("Cannot use interval type in the table schema."))
+      val e2 = intercept[AnalysisException](
+        sql(s"$action TABLE table_name (id array<interval>) USING $v2Format"))
+      assert(e2.getMessage.contains("Cannot use interval type in the table schema."))
+    }
+  }
+
+  test("CTAS/RTAS: invalid schema if it has interval type") {
+    Seq("CREATE", "REPLACE").foreach { action =>
+      val e1 = intercept[AnalysisException](
+        sql(s"$action TABLE table_name USING $v2Format as select interval 1 day"))
+      assert(e1.getMessage.contains("Cannot use interval type in the table schema."))
+      val e2 = intercept[AnalysisException](
+        sql(s"$action TABLE table_name USING $v2Format as select array(interval 1 day)"))
+      assert(e2.getMessage.contains("Cannot use interval type in the table schema."))
+    }
+  }
+
   test("CreateTableAsSelect: use v2 plan because catalog is set") {
     val basicCatalog = catalog("testcat").asTableCatalog
     val atomicCatalog = catalog("testcat_atomic").asTableCatalog
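The write-side V2WriteCommand check matters because a table with an interval column can still come into existence through the TableCatalog API directly, which never passes through the analyzer; that is exactly what the DataFrameSuite test above does before exercising the three DataFrameWriterV2 paths. A sketch of the same scenario outside the test harness; the object name is illustrative, and it again assumes the test-only InMemoryTableCatalog on the classpath:

import java.util.Collections

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
import org.apache.spark.sql.types.StructType

object IntervalWriteDemo extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .config("spark.sql.catalog.testcat",
      "org.apache.spark.sql.connector.InMemoryTableCatalog")
    .getOrCreate()

  // Create the interval table through the TableCatalog API directly;
  // no logical plan is analyzed, so CheckAnalysis cannot reject it here.
  val catalog = spark.sessionState.catalogManager.catalog("testcat")
    .asInstanceOf[TableCatalog]
  catalog.createTable(
    Identifier.of(Array(), "t"),
    new StructType().add("i", "interval"),
    Array.empty,
    Collections.emptyMap[String, String])

  // Any v2 write into it is then rejected at analysis time; the same holds
  // for overwrite(...) and overwritePartitions().
  val df = spark.sql("SELECT interval 1 day AS i")
  try df.writeTo("testcat.t").append()
  catch { case e: AnalysisException => println(e.getMessage) }
  // -> Cannot use interval type in the table schema.
}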