diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 9d6e0a6d6ce6..4b132d8ab6c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -180,7 +180,7 @@ object ExternalCatalogUtils {
   }
 
   def convertNullPartitionValues(spec: TablePartitionSpec): TablePartitionSpec = {
-    spec.mapValues(v => if (v == null) DEFAULT_PARTITION_NAME else v).toMap
+    spec.mapValues(v => if (v == null) DEFAULT_PARTITION_NAME else v).map(identity).toMap
   }
 }
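
The added `.map(identity)` looks redundant, but it appears to matter for Scala 2.12 cross-builds: there, `Map#mapValues` returns a lazy view that is itself a `Map`, so `.toMap` hands back the view unchanged and the null-replacement closure re-runs on every lookup. Below is a standalone sketch of that pitfall, not part of the patch; only `DEFAULT_PARTITION_NAME` mirrors the value in `ExternalCatalogUtils`, everything else is illustrative.

```scala
// Sketch of the Scala 2.12 mapValues pitfall that .map(identity).toMap guards against.
object MapValuesStrictnessDemo extends App {
  val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__" // Spark's sentinel for null partition values

  var calls = 0
  val spec: Map[String, String] = Map("p1" -> null, "p2" -> "a")
  def convert(v: String): String = { calls += 1; if (v == null) DEFAULT_PARTITION_NAME else v }

  val lazyView = spec.mapValues(convert) // on 2.12: a lazy view, not a strict Map
  lazyView("p1")
  lazyView("p1")
  println(calls) // 2: convert re-ran on each access through the view

  calls = 0
  val strict = spec.mapValues(convert).map(identity).toMap // map(identity) forces a strict copy
  strict("p1")
  strict("p1")
  println(calls) // 2, from construction only; lookups no longer re-run convert
}
```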
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 0d16f46d049a..1c4db8746c46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -94,6 +94,15 @@ class InMemoryCatalog(
     }
   }
 
+  private def toCatalogPartitionSpec = ExternalCatalogUtils.convertNullPartitionValues(_)
+  private def toCatalogPartitionSpecs(specs: Seq[TablePartitionSpec]): Seq[TablePartitionSpec] = {
+    specs.map(toCatalogPartitionSpec)
+  }
+  private def toCatalogPartitionSpec(
+      parts: Seq[CatalogTablePartition]): Seq[CatalogTablePartition] = {
+    parts.map(part => part.copy(spec = toCatalogPartitionSpec(part.spec)))
+  }
+
   // --------------------------------------------------------------------------
   // Databases
   // --------------------------------------------------------------------------
@@ -389,10 +398,11 @@ class InMemoryCatalog(
   override def createPartitions(
       db: String,
       table: String,
-      parts: Seq[CatalogTablePartition],
+      newParts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = synchronized {
     requireTableExists(db, table)
     val existingParts = catalog(db).tables(table).partitions
+    val parts = toCatalogPartitionSpec(newParts)
     if (!ignoreIfExists) {
       val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec }
       if (dupSpecs.nonEmpty) {
@@ -428,12 +438,13 @@ class InMemoryCatalog(
   override def dropPartitions(
       db: String,
       table: String,
-      partSpecs: Seq[TablePartitionSpec],
+      parts: Seq[TablePartitionSpec],
       ignoreIfNotExists: Boolean,
       purge: Boolean,
       retainData: Boolean): Unit = synchronized {
     requireTableExists(db, table)
     val existingParts = catalog(db).tables(table).partitions
+    val partSpecs = toCatalogPartitionSpecs(parts)
     if (!ignoreIfNotExists) {
       val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s }
       if (missingSpecs.nonEmpty) {
@@ -467,8 +478,10 @@ class InMemoryCatalog(
   override def renamePartitions(
       db: String,
       table: String,
-      specs: Seq[TablePartitionSpec],
-      newSpecs: Seq[TablePartitionSpec]): Unit = synchronized {
+      fromSpecs: Seq[TablePartitionSpec],
+      toSpecs: Seq[TablePartitionSpec]): Unit = synchronized {
+    val specs = toCatalogPartitionSpecs(fromSpecs)
+    val newSpecs = toCatalogPartitionSpecs(toSpecs)
     require(specs.size == newSpecs.size, "number of old and new partition specs differ")
     requirePartitionsExist(db, table, specs)
     requirePartitionsNotExist(db, table, newSpecs)
@@ -507,7 +520,8 @@ class InMemoryCatalog(
   override def alterPartitions(
       db: String,
       table: String,
-      parts: Seq[CatalogTablePartition]): Unit = synchronized {
+      alterParts: Seq[CatalogTablePartition]): Unit = synchronized {
+    val parts = toCatalogPartitionSpec(alterParts)
     requirePartitionsExist(db, table, parts.map(p => p.spec))
     parts.foreach { p =>
       catalog(db).tables(table).partitions.put(p.spec, p)
@@ -517,7 +531,8 @@ class InMemoryCatalog(
   override def getPartition(
       db: String,
       table: String,
-      spec: TablePartitionSpec): CatalogTablePartition = synchronized {
+      partSpec: TablePartitionSpec): CatalogTablePartition = synchronized {
+    val spec = toCatalogPartitionSpec(partSpec)
     requirePartitionsExist(db, table, Seq(spec))
     catalog(db).tables(table).partitions(spec)
   }
@@ -525,7 +540,8 @@ class InMemoryCatalog(
   override def getPartitionOption(
       db: String,
       table: String,
-      spec: TablePartitionSpec): Option[CatalogTablePartition] = synchronized {
+      partSpec: TablePartitionSpec): Option[CatalogTablePartition] = synchronized {
+    val spec = toCatalogPartitionSpec(partSpec)
     if (!partitionExists(db, table, spec)) {
       None
     } else {
@@ -536,9 +552,9 @@ class InMemoryCatalog(
   override def listPartitionNames(
       db: String,
       table: String,
-      partialSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized {
+      partSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized {
     val partitionColumnNames = getTable(db, table).partitionColumnNames
-
+    val partialSpec = partSpec.map(toCatalogPartitionSpec)
     listPartitions(db, table, partialSpec).map { partition =>
       partitionColumnNames.map { name =>
         val partValue = if (partition.spec(name) == null) {
@@ -557,7 +573,7 @@ class InMemoryCatalog(
       partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = synchronized {
     requireTableExists(db, table)
 
-    partialSpec match {
+    partialSpec.map(toCatalogPartitionSpec) match {
       case None => catalog(db).tables(table).partitions.values.toSeq
       case Some(partial) =>
         catalog(db).tables(table).partitions.toSeq.collect {
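
Every public partition entry point above now funnels incoming specs through these helpers, so a spec written as `PARTITION (p1 = null)` and a spec carrying the stored sentinel resolve to the same key in the in-memory `partitions` map. A minimal model of the normalization, assuming `DEFAULT_PARTITION_NAME` is the `__HIVE_DEFAULT_PARTITION__` sentinel (a sketch, not the Spark code):

```scala
// Minimal model of the spec normalization; TablePartitionSpec mirrors
// Catalyst's alias Map[String, String].
object NullSpecNormalization extends App {
  type TablePartitionSpec = Map[String, String]
  val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__"

  def toCatalogPartitionSpec(spec: TablePartitionSpec): TablePartitionSpec =
    spec.map { case (k, v) => k -> (if (v == null) DEFAULT_PARTITION_NAME else v) }

  // A partition created with a null value is stored under the sentinel...
  val stored = toCatalogPartitionSpec(Map("p1" -> null))
  // ...and a later DROP PARTITION (p1 = null) normalizes to the same key,
  // which is what lets the lookup in the partitions map succeed.
  assert(stored == Map("p1" -> DEFAULT_PARTITION_NAME))
  assert(stored == toCatalogPartitionSpec(Map("p1" -> null)))
}
```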
sql(s"ALTER TABLE $t DROP PARTITION (p1 = null)") + checkPartitions(t) + } + + withNamespaceAndTable("ns", "tbl") { t => + insertAndDropNullPart(t, s"INSERT INTO TABLE $t PARTITION (p1 = null) SELECT 0") + } + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withNamespaceAndTable("ns", "tbl") { t => + insertAndDropNullPart(t, s"INSERT OVERWRITE TABLE $t VALUES (0, null)") + } + } + } + test("SPARK-34161, SPARK-34138, SPARK-34099: keep dependents cached after table altering") { withNamespaceAndTable("ns", "tbl") { t => sql(s"CREATE TABLE $t (id int, part int) $defaultUsing PARTITIONED BY (part)")