diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
index 9cb71bb51294c..0c2a3ac4534d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala
@@ -178,7 +178,7 @@ object ExternalCatalogUtils {
   }
 
   def convertNullPartitionValues(spec: TablePartitionSpec): TablePartitionSpec = {
-    spec.mapValues(v => if (v == null) DEFAULT_PARTITION_NAME else v).toMap
+    spec.mapValues(v => if (v == null) DEFAULT_PARTITION_NAME else v).map(identity).toMap
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index c05950300279f..90e69469eef69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -94,6 +94,15 @@ class InMemoryCatalog(
     }
   }
 
+  private def toCatalogPartitionSpec = ExternalCatalogUtils.convertNullPartitionValues(_)
+  private def toCatalogPartitionSpecs(specs: Seq[TablePartitionSpec]): Seq[TablePartitionSpec] = {
+    specs.map(toCatalogPartitionSpec)
+  }
+  private def toCatalogPartitionSpec(
+      parts: Seq[CatalogTablePartition]): Seq[CatalogTablePartition] = {
+    parts.map(part => part.copy(spec = toCatalogPartitionSpec(part.spec)))
+  }
+
   // --------------------------------------------------------------------------
   // Databases
   // --------------------------------------------------------------------------
@@ -389,10 +398,11 @@ class InMemoryCatalog(
   override def createPartitions(
       db: String,
       table: String,
-      parts: Seq[CatalogTablePartition],
+      newParts: Seq[CatalogTablePartition],
       ignoreIfExists: Boolean): Unit = synchronized {
     requireTableExists(db, table)
     val existingParts = catalog(db).tables(table).partitions
+    val parts = toCatalogPartitionSpec(newParts)
     if (!ignoreIfExists) {
       val dupSpecs = parts.collect { case p if existingParts.contains(p.spec) => p.spec }
       if (dupSpecs.nonEmpty) {
@@ -428,12 +438,13 @@ class InMemoryCatalog(
   override def dropPartitions(
       db: String,
       table: String,
-      partSpecs: Seq[TablePartitionSpec],
+      parts: Seq[TablePartitionSpec],
       ignoreIfNotExists: Boolean,
       purge: Boolean,
       retainData: Boolean): Unit = synchronized {
     requireTableExists(db, table)
     val existingParts = catalog(db).tables(table).partitions
+    val partSpecs = toCatalogPartitionSpecs(parts)
     if (!ignoreIfNotExists) {
       val missingSpecs = partSpecs.collect { case s if !existingParts.contains(s) => s }
       if (missingSpecs.nonEmpty) {
@@ -467,8 +478,10 @@ class InMemoryCatalog(
   override def renamePartitions(
       db: String,
       table: String,
-      specs: Seq[TablePartitionSpec],
-      newSpecs: Seq[TablePartitionSpec]): Unit = synchronized {
+      fromSpecs: Seq[TablePartitionSpec],
+      toSpecs: Seq[TablePartitionSpec]): Unit = synchronized {
+    val specs = toCatalogPartitionSpecs(fromSpecs)
+    val newSpecs = toCatalogPartitionSpecs(toSpecs)
     require(specs.size == newSpecs.size, "number of old and new partition specs differ")
     requirePartitionsExist(db, table, specs)
     requirePartitionsNotExist(db, table, newSpecs)
@@ -507,7 +520,8 @@ class InMemoryCatalog(
   override def alterPartitions(
       db: String,
       table: String,
-      parts: Seq[CatalogTablePartition]): Unit = synchronized {
+      alterParts: Seq[CatalogTablePartition]): Unit = synchronized {
+    val parts = toCatalogPartitionSpec(alterParts)
     requirePartitionsExist(db, table, parts.map(p => p.spec))
     parts.foreach { p =>
       catalog(db).tables(table).partitions.put(p.spec, p)
@@ -517,7 +531,8 @@ class InMemoryCatalog(
   override def getPartition(
       db: String,
       table: String,
-      spec: TablePartitionSpec): CatalogTablePartition = synchronized {
+      partSpec: TablePartitionSpec): CatalogTablePartition = synchronized {
+    val spec = toCatalogPartitionSpec(partSpec)
     requirePartitionsExist(db, table, Seq(spec))
     catalog(db).tables(table).partitions(spec)
   }
@@ -525,7 +540,8 @@ class InMemoryCatalog(
   override def getPartitionOption(
       db: String,
       table: String,
-      spec: TablePartitionSpec): Option[CatalogTablePartition] = synchronized {
+      partSpec: TablePartitionSpec): Option[CatalogTablePartition] = synchronized {
+    val spec = toCatalogPartitionSpec(partSpec)
     if (!partitionExists(db, table, spec)) {
       None
     } else {
@@ -536,9 +552,9 @@ class InMemoryCatalog(
   override def listPartitionNames(
       db: String,
       table: String,
-      partialSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized {
+      partSpec: Option[TablePartitionSpec] = None): Seq[String] = synchronized {
     val partitionColumnNames = getTable(db, table).partitionColumnNames
-
+    val partialSpec = partSpec.map(toCatalogPartitionSpec)
     listPartitions(db, table, partialSpec).map { partition =>
       partitionColumnNames.map { name =>
         val partValue = if (partition.spec(name) == null) {
@@ -557,7 +573,7 @@ class InMemoryCatalog(
       partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = synchronized {
     requireTableExists(db, table)
 
-    partialSpec match {
+    partialSpec.map(toCatalogPartitionSpec) match {
      case None => catalog(db).tables(table).partitions.values.toSeq
      case Some(partial) =>
        catalog(db).tables(table).partitions.toSeq.collect {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 8039c9b6f04b5..fdad69d1f1bbd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -3575,15 +3575,6 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     })
   }
-
-  test("SPARK-33591: null as a partition value") {
-    val t = "part_table"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (col1 INT, p1 STRING) USING PARQUET PARTITIONED BY (p1)")
-      sql(s"INSERT INTO TABLE $t PARTITION (p1 = null) SELECT 0")
-      checkAnswer(sql(s"SELECT * FROM $t"), Row(0, null))
-    }
-  }
 }
 
 case class Foo(bar: Option[String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 5986cdc78d6b4..847bc668313a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
@@ -1734,9 +1735,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
 
     // null partition values
     createTablePartition(catalog, Map("a" -> null, "b" -> null), tableIdent)
-    val nullPartValue = if (isUsingHiveMetastore) "__HIVE_DEFAULT_PARTITION__" else null
     assert(catalog.listPartitions(tableIdent).map(_.spec).toSet ==
-      Set(Map("a" -> nullPartValue, "b" -> nullPartValue)))
+      Set(Map("a" -> "__HIVE_DEFAULT_PARTITION__", "b" -> "__HIVE_DEFAULT_PARTITION__")))
     sql("ALTER TABLE tab1 DROP PARTITION (a = null, b = null)")
     assert(catalog.listPartitions(tableIdent).isEmpty)
   }
@@ -3091,6 +3091,35 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       assert(sql(s"SHOW TABLE EXTENDED LIKE '$t' PARTITION(a = 1)").count() === 1)
     }
   }
+
+  test("SPARK-33591, SPARK-34203: insert and drop partitions with null values") {
+    def checkPartitions(t: String, expected: Map[String, String]*): Unit = {
+      val partitions = sql(s"SHOW PARTITIONS $t")
+        .collect()
+        .toSet
+        .map((row: Row) => row.getString(0))
+        .map(PartitioningUtils.parsePathFragment)
+      assert(partitions === expected.toSet)
+    }
+    val defaultUsing = "USING " + (if (isUsingHiveMetastore) "hive" else "parquet")
+    def insertAndDropNullPart(t: String, insertCmd: String): Unit = {
+      sql(s"CREATE TABLE $t (col1 INT, p1 STRING) $defaultUsing PARTITIONED BY (p1)")
+      sql(insertCmd)
+      checkPartitions(t, Map("p1" -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME))
+      sql(s"ALTER TABLE $t DROP PARTITION (p1 = null)")
+      checkPartitions(t)
+    }
+
+    withTable("tbl") {
+      insertAndDropNullPart("tbl", s"INSERT INTO TABLE tbl PARTITION (p1 = null) SELECT 0")
+    }
+
+    withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+      withTable("tbl") {
+        insertAndDropNullPart("tbl", s"INSERT OVERWRITE TABLE tbl VALUES (0, null)")
+      }
+    }
+  }
 }
 
 object FakeLocalFsFileSystem {
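
For context (not part of the patch): the change routes every partition spec entering `InMemoryCatalog` through `ExternalCatalogUtils.convertNullPartitionValues`, so a literal null in `PARTITION (p1 = null)` and the partition stored at insert time both resolve to the same key, `__HIVE_DEFAULT_PARTITION__`. The snippet below is a minimal standalone sketch of that normalization; the object name and the `main` harness are illustrative only, while the constant and the null-to-default mapping come from the patch. The extra `.map(identity)` added in `ExternalCatalogUtils` appears intended to force a fully materialized `Map` out of `mapValues` rather than a lazy view.

```scala
// Standalone sketch of the null-partition-value normalization (illustrative names).
object NullPartitionValueSketch {
  type TablePartitionSpec = Map[String, String]

  // Same constant the patch uses via ExternalCatalogUtils.DEFAULT_PARTITION_NAME.
  val DEFAULT_PARTITION_NAME = "__HIVE_DEFAULT_PARTITION__"

  // Replace null partition values with the Hive default partition name so the
  // spec can be used as a stable key in the catalog's partition map.
  def convertNullPartitionValues(spec: TablePartitionSpec): TablePartitionSpec =
    spec.map { case (k, v) => k -> (if (v == null) DEFAULT_PARTITION_NAME else v) }

  def main(args: Array[String]): Unit = {
    // The insert path stores the partition under the default name...
    val stored = convertNullPartitionValues(Map[String, String]("p1" -> null))
    // ...and DROP PARTITION normalizes its spec the same way, so both hit the same key.
    val dropped = convertNullPartitionValues(Map[String, String]("p1" -> null))
    assert(stored == dropped)
    assert(stored("p1") == DEFAULT_PARTITION_NAME)
    println(stored) // Map(p1 -> __HIVE_DEFAULT_PARTITION__)
  }
}
```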