[SPARK-18891][SQL] Support for Scala Map collection types #16986
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._

 /**
```
@@ -652,6 +652,173 @@ case class MapObjects private(

```scala
  }
}

object CollectObjectsToMap {
  private val curId = new java.util.concurrent.atomic.AtomicInteger()

  /**
   * Construct an instance of CollectObjectsToMap case class.
   *
   * @param keyFunction The function applied on the key collection elements.
   * @param valueFunction The function applied on the value collection elements.
   * @param inputData An expression that when evaluated returns a map object.
   * @param collClass The type of the resulting collection.
   */
  def apply(
      keyFunction: Expression => Expression,
      valueFunction: Expression => Expression,
      inputData: Expression,
      collClass: Class[_]): CollectObjectsToMap = {
    val id = curId.getAndIncrement()
    val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
    val mapType = inputData.dataType.asInstanceOf[MapType]
    val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
    val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
    val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
    val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
    CollectObjectsToMap(
      keyLoopValue, keyFunction(keyLoopVar),
      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
      inputData, collClass)
  }
}

/**
 * Expression used to convert a Catalyst Map to an external Scala Map.
 * The collection is constructed using the associated builder, obtained by calling `newBuilder`
 * on the collection's companion object.
 *
 * @param keyLoopValue the name of the loop variable that is used when iterating over the key
 *                     collection, and which is used as input for the `keyLambdaFunction`
 * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
 *                          a lambda function to handle collection elements.
 * @param valueLoopValue the name of the loop variable that is used when iterating over the value
 *                       collection, and which is used as input for the `valueLambdaFunction`
 * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
 *                        the value collection, and which is used as input for the
 *                        `valueLambdaFunction`
 * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
 *                            a lambda function to handle collection elements.
 * @param inputData An expression that when evaluated returns a map object.
 * @param collClass The type of the resulting collection.
 */
case class CollectObjectsToMap private(
    keyLoopValue: String,
    keyLambdaFunction: Expression,
    valueLoopValue: String,
    valueLoopIsNull: String,
    valueLambdaFunction: Expression,
    inputData: Expression,
    collClass: Class[_]) extends Expression with NonSQLExpression {

  override def nullable: Boolean = inputData.nullable

  override def children: Seq[Expression] =
    keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil

  override def eval(input: InternalRow): Any =
    throw new UnsupportedOperationException("Only code-generated evaluation is supported")

  override def dataType: DataType = ObjectType(collClass)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
    // When we want to apply MapObjects on it, we have to use it.
    def inputDataType(dataType: DataType) = dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => dataType
    }

    val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
    val keyElementJavaType = ctx.javaType(mapType.keyType)
    ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
    val genKeyFunction = keyLambdaFunction.genCode(ctx)
    val valueElementJavaType = ctx.javaType(mapType.valueType)
    ctx.addMutableState("boolean", valueLoopIsNull, "")
    ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
    val genValueFunction = valueLambdaFunction.genCode(ctx)
    val genInputData = inputData.genCode(ctx)
    val dataLength = ctx.freshName("dataLength")
    val loopIndex = ctx.freshName("loopIndex")
    val tupleLoopValue = ctx.freshName("tupleLoopValue")
    val builderValue = ctx.freshName("builderValue")

    val getLength = s"${genInputData.value}.numElements()"

    val keyArray = ctx.freshName("keyArray")
    val valueArray = ctx.freshName("valueArray")
    val getKeyArray =
      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
    val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
    val getValueArray =
      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
```
> **Contributor:** nit: we can inline this `getLength`
```scala
    val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)

    // Make a copy of the data if it's unsafe-backed
    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
      s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
    def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
      lambdaFunction.dataType match {
        case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
        case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
        case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
        case _ => genFunction.value
      }
    val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
    val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

    val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"

    val builderClass = classOf[Builder[_, _]].getName
    val constructBuilder = s"""
      $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
      $builderValue.sizeHint($dataLength);
    """

    val tupleClass = classOf[(_, _)].getName
    val appendToBuilder = s"""
      $tupleClass $tupleLoopValue;

      if (${genValueFunction.isNull}) {
        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
      } else {
        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
      }

      $builderValue.$$plus$$eq($tupleLoopValue);
    """
    val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"

    val code = s"""
      ${genInputData.code}
      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

      if (!${genInputData.isNull}) {
        int $dataLength = $getLength;
        $constructBuilder
        $getKeyArray
        $getValueArray
```
> **Contributor:** nit: we can also inline this.
```scala
        int $loopIndex = 0;
        while ($loopIndex < $dataLength) {
          $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
          $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
          $valueLoopNullCheck
```
> **Contributor:** We can also inline this. The principle is that we should inline simple snippets like these as much as possible, so that when you look at the code block it's clearer what's going on.
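To make the inlining suggestion concrete, here is a hypothetical before/after (illustrative only, not part of the patch): rather than binding a one-line snippet to a `val` and splicing its name into the template, the statement can be written directly where it is used.

```scala
// Illustrative only: "inlining" a one-line snippet in a codegen template.
val valueLoopIsNull = "valueLoopIsNull0"
val valueArray = "valueArray0"
val loopIndex = "loopIndex0"

// Before: an intermediate val holds the snippet, then is spliced by name.
val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
val before = s"""
  $valueLoopNullCheck
"""

// After: the statement appears directly in the template, so the reader sees
// the generated Java at the point where it is emitted.
val after = s"""
  $valueLoopIsNull = $valueArray.isNullAt($loopIndex);
"""
```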
```scala
          ${genKeyFunction.code}
          ${genValueFunction.code}

          $appendToBuilder

          $loopIndex += 1;
        }

        $getBuilderResult
      }
    """
    ev.copy(code = code, isNull = genInputData.isNull)
  }
}

object ExternalMapToCatalyst {
  private val curId = new java.util.concurrent.atomic.AtomicInteger()
```
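For readers unfamiliar with Scala's builder pattern: the generated Java obtains a `Builder` from the companion object of the target collection class (written `Map$.MODULE$.newBuilder()` in Java syntax) and appends one `Tuple2` per map entry. A minimal plain-Scala sketch of the same call sequence, assuming the target type `immutable.Map[Int, String]` (illustrative, not part of the patch):

```scala
// Sketch of the builder protocol the generated code follows.
import scala.collection.mutable.Builder

val builder: Builder[(Int, String), Map[Int, String]] = Map.newBuilder[Int, String]
builder.sizeHint(2)   // mirrors $builderValue.sizeHint($dataLength)
builder += (1 -> "a") // mirrors $builderValue.$plus$eq($tupleLoopValue)
builder += (2 -> "b")
val result: Map[Int, String] = builder.result() // mirrors $getBuilderResult
```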
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql

 import scala.collection.immutable.Queue
+import scala.collection.mutable.{LinkedHashMap => LHMap}
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.sql.test.SharedSQLContext
```
```diff
@@ -30,8 +31,14 @@ case class ListClass(l: List[Int])

 case class QueueClass(q: Queue[Int])

+case class MapClass(m: Map[Int, Int])
+
+case class LHMapClass(m: LHMap[Int, Int])
+
 case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)

+case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
```
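The new case classes exercise maps nested inside products. As a hedged illustration of what the suite relies on (assuming `spark-sql` on the classpath), the encoder for such a class can also be derived explicitly; the schema comment below is the expected result, not a quote from the patch:

```scala
// Sketch: an explicit product encoder for a case class wrapping a Map.
import org.apache.spark.sql.{Encoder, Encoders}

case class MapClass(m: Map[Int, Int])

val enc: Encoder[MapClass] = Encoders.product[MapClass]
println(enc.schema) // expected: a StructType with a single map<int,int> field "m"
```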
@@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {

```scala
      ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
  }

  test("arbitrary maps") {
```
> **Contributor:** nit: this suite is …
```scala
    checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2))
    checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong))
    checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble))
    checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat))
    checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte))
    checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort))
    checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false))
    checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2"))
    checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2)))
    checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2)))
    checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong))

    checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2))
    checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong))
    checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble))
    checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat))
    checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte))
    checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort))
    checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false))
    checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
    checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
    checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
    checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
```
> **Contributor:** Can we add some nested map cases? e.g. …
>
> **Author:** Added as a separate test case (same as for sequences).
```scala
  }

  ignore("SPARK-19104: map and product combinations") {
```
> **Contributor:** Why `ignore`?
>
> **Author:** I added these tests for SPARK-19104 (CompileException with Map and Case Class in Spark 2.1.0), since I thought I could fix that issue as part of this PR. It turned out to be more complicated than I anticipated, so I left the tests in place and marked them ignored. I can remove them.
```scala
    // Case classes
    checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
    checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3))
    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
      Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
    checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3))))
    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3))
    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
      LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))

    checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2)))
    checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
      Map(1 -> LHMapClass(LHMap(2 -> 3))))
    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
      Map(LHMapClass(LHMap(1 -> 2)) -> 3))
    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
      Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
    checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
      LHMap(1 -> LHMapClass(LHMap(2 -> 3))))
    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
      LHMap(LHMapClass(LHMap(1 -> 2)) -> 3))
    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
      LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))

    val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4)))
    checkDataset(Seq(complex).toDS(), complex)
    checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex))
    checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5))
    checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex))
    checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex))
    checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5))
    checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex))

    // Tuples
    checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4))
    checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4))
    checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4))
    checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4))
    checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(),
      LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2"))))

    // Complex
    checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(),
      LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
  }

  test("nested sequences") {
    checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
    checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
  }

  test("nested maps") {
    checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3)))
    checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
  }

  test("package objects") {
    import packageobject._
    checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
```
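End to end, the behavior these tests pin down is that Scala map types now round-trip through a Dataset. A minimal standalone sketch under assumed defaults (local master, a plain `assert` standing in for `checkDataset`):

```scala
import scala.collection.mutable.{LinkedHashMap => LHMap}
import org.apache.spark.sql.SparkSession

object MapRoundTrip {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("map-round-trip").getOrCreate()
    import spark.implicits._

    // With CollectObjectsToMap in place, non-default Map types deserialize back
    // to their original Scala class instead of failing.
    val ds = Seq(LHMap(1 -> "a", 2 -> "b")).toDS()
    assert(ds.collect().head == LHMap(1 -> "a", 2 -> "b"))
    spark.stop()
  }
}
```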