diff --git a/examples/scala/src/main/scala/example/EvolutionWithMap.scala b/examples/scala/src/main/scala/example/EvolutionWithMap.scala
new file mode 100644
index 00000000000..4b6175f8a15
--- /dev/null
+++ b/examples/scala/src/main/scala/example/EvolutionWithMap.scala
@@ -0,0 +1,98 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package example
+
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.SparkSession
+
+object EvolutionWithMap {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession.builder()
+      .appName("EvolutionWithMap")
+      .master("local[*]")
+      .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
+      .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog")
+      .getOrCreate()
+
+    import spark.implicits._
+
+    val tableName = "insert_map_schema_evolution"
+
+    try {
+      // Define initial schema
+      val initialSchema = StructType(Seq(
+        StructField("key", IntegerType, nullable = false),
+        StructField("metrics", MapType(StringType, StructType(Seq(
+          StructField("id", IntegerType, nullable = false),
+          StructField("value", IntegerType, nullable = false)
+        ))))
+      ))
+
+      val data = Seq(
+        Row(1, Map("event" -> Row(1, 1)))
+      )
+
+      val rdd = spark.sparkContext.parallelize(data)
+
+      val initialDf = spark.createDataFrame(rdd, initialSchema)
+
+      initialDf.write
+        .option("overwriteSchema", "true")
+        .mode("overwrite")
+        .format("delta")
+        .saveAsTable(s"$tableName")
+
+      // Define the evolved schema with a simultaneous change of a StructField name
+      // and an additional field in the map column
+      val evolvedSchema = StructType(Seq(
+        StructField("renamed_key", IntegerType, nullable = false),
+        StructField("metrics", MapType(StringType, StructType(Seq(
+          StructField("id", IntegerType, nullable = false),
+          StructField("value", IntegerType, nullable = false),
+          StructField("comment", StringType, nullable = true)
+        ))))
+      ))
+
+      val evolvedData = Seq(
+        Row(1, Map("event" -> Row(1, 1, "deprecated")))
+      )
+
+      val evolvedRDD = spark.sparkContext.parallelize(evolvedData)
+
+      val modifiedDf = spark.createDataFrame(evolvedRDD, evolvedSchema)
+
+      // The insert below would fail without schema evolution for map types
+      modifiedDf.write
+        .mode("append")
+        .option("mergeSchema", "true")
+        .format("delta")
+        .insertInto(s"$tableName")
+
+      spark.sql(s"SELECT * FROM $tableName").show(false)
+
+    } finally {
+
+      // Cleanup
+      spark.sql(s"DROP TABLE IF EXISTS $tableName")
+
+      spark.stop()
+    }
+
+  }
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
index f20e79a58ed..07d1c6828c0 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -69,6 +69,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
 /**
  * Analysis rules for Delta. Currently, these rules enable schema enforcement / evolution with
  * INSERT INTO.
  */
@@ -930,6 +931,18 @@ class DeltaAnalysis(session: SparkSession)
         // Keep the type from the query, the target schema will be updated to widen the existing
         // type to match it.
         attr
+      case (s: MapType, t: MapType) if s != t =>
+        // Only handle the case where the value type of the MapType is a StructType.
+        // This could be revisited and expanded in the future when more complex nested
+        // operations are needed, e.g. in combination with ALTER TABLE COLUMN operations
+        // on deeply nested fields.
+        (s.valueType, t.valueType) match {
+          case (structS: StructType, structT: StructType) if structS != structT =>
+            addCastsToMaps(tblName, attr, s, t, allowTypeWidening)
+          case _ =>
+            // Default for all other MapType cases
+            getCastFunction(attr, targetAttr.dataType, targetAttr.name)
+        }
       case _ =>
         getCastFunction(attr, targetAttr.dataType, targetAttr.name)
     }
@@ -1049,6 +1062,7 @@ class DeltaAnalysis(session: SparkSession)
   /**
    * Recursively casts structs in case it contains null types.
    * TODO: Support other complex types like MapType and ArrayType
+   * The MapType case that calls addCastsToMaps addresses the MapType part of this TODO.
    */
   private def addCastsToStructs(
       tableName: String,
@@ -1067,6 +1081,8 @@ class DeltaAnalysis(session: SparkSession)
             val subField = Alias(GetStructField(parent, i, Option(name)), target(i).name)(
               explicitMetadata = Option(metadata))
             addCastsToStructs(tableName, subField, nested, t, allowTypeWidening)
+            // We could also handle MapType within the struct here, but there is a restriction
+            // on deeply nested operations that may result in a maxIterations error.
           case o =>
             val field = parent.qualifiedName + "." + name
             val targetName = parent.qualifiedName + "." + target(i).name
@@ -1124,6 +1140,48 @@ class DeltaAnalysis(session: SparkSession)
     DeltaViewHelper.stripTempViewForMerge(plan, conf)
   }
 
+  /**
+   * Recursively casts map values in case they contain null types.
+   */
+  private def addCastsToMaps(
+      tableName: String,
+      parent: NamedExpression,
+      sourceMapType: MapType,
+      targetMapType: MapType,
+      allowTypeWidening: Boolean): Expression = {
+    // First get the keys from the map
+    val keysExpr = MapKeys(parent)
+
+    // Create a transformation for the values
+    val transformLambdaFunc = {
+      val elementVar = NamedLambdaVariable(
+        "elementVar", sourceMapType.valueType, sourceMapType.valueContainsNull)
+      val castedValue = sourceMapType.valueType match {
+        case structType: StructType =>
+          // Handle StructType values
+          addCastsToStructs(
+            tableName,
+            elementVar,
+            structType,
+            targetMapType.valueType.asInstanceOf[StructType],
+            allowTypeWidening
+          )
+        case _ =>
+          // Not expected to get here: see addCastsToColumn
+          throw new IllegalArgumentException(
+            "Target type must be a StructType")
+      }
+
+      LambdaFunction(castedValue, Seq(elementVar))
+    }
+
+    val transformedValues = ArrayTransform(
+      MapValues(parent), transformLambdaFunc)
+
+    // Create a new map from the original keys and the transformed values
+    MapFromArrays(keysExpr, transformedValues)
+  }
+
   /**
    * Verify the input plan for a SINGLE streaming query with the following:
    * 1. Schema location must be under checkpoint location, if not lifted by flag
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala
index 55cd149a72a..d582c73b672 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/typewidening/TypeWideningInsertSchemaEvolutionSuite.scala
@@ -297,8 +297,7 @@ trait TypeWideningInsertSchemaEvolutionTests
         metadata = typeWideningMetadata(version = 1, from = ShortType, to = IntegerType)))))
   )
 
-  // The next two tests document inconsistencies when handling maps. Using SQL or INSERT by position
-  // doesn't allow type evolution but using dataframe INSERT by name does.
+  // Maps now allow type evolution for INSERT by position and by name, in both SQL and dataframe.
   testInserts("nested struct type evolution with field upcast in map")(
     initialData = TestData(
       "key int, m map<string, struct<x: int, y: short>>",
@@ -310,10 +309,11 @@ trait TypeWideningInsertSchemaEvolutionTests
       Seq("""{ "key": 1, "m": { "a": { "x": 3, "y": 4 } } }""")),
     expectedResult = ExpectedResult.Success(new StructType()
       .add("key", IntegerType)
-      // Type evolution wasn't applied in the map.
+      // Type evolution is now applied in the map.
       .add("m", MapType(StringType, new StructType()
         .add("x", IntegerType)
-        .add("y", ShortType)))),
+        .add("y", IntegerType, nullable = true,
+          metadata = typeWideningMetadata(version = 1, from = ShortType, to = IntegerType))))),
     excludeInserts = insertsDataframe.intersect(insertsByName)
   )
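
Note (not part of the patch): the expression tree built by addCastsToMaps corresponds to a plain DataFrame-level pattern -- keep the map keys as-is, transform only the map values, and reassemble the map, i.e. MapFromArrays(MapKeys(m), ArrayTransform(MapValues(m), value -> cast(value))). The sketch below illustrates that shape with the public Spark functions map_keys, map_values, transform and map_from_arrays; the object name MapRewriteSketch, the column names and the cast to long are illustrative assumptions only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object MapRewriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("MapRewriteSketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Source column shaped like the example above: map<string, struct<_1: int, _2: int>>
    val df = Seq((1, Map("event" -> (1, 2)))).toDF("key", "metrics")

    // Keys are kept untouched; each value is rebuilt/cast, then the map is reassembled,
    // mirroring MapFromArrays(MapKeys(m), ArrayTransform(MapValues(m), value -> cast(value))).
    val rebuilt = df.select(
      $"key",
      map_from_arrays(
        map_keys($"metrics"),
        transform(
          map_values($"metrics"),
          v => struct(v("_1").as("id"), v("_2").cast("long").as("value")))
      ).as("metrics"))

    rebuilt.printSchema()
    spark.stop()
  }
}

Rebuilding the map from its keys plus transformed values, rather than casting the map as a whole, is what lets the patch reuse the existing struct-casting logic in addCastsToStructs for map values.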