Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -333,31 +333,15 @@ object ScalaReflection extends ScalaReflection {
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t

val keyData =
Invoke(
MapObjects(
p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
returnNullable = false),
schemaFor(keyType).dataType),
"array",
ObjectType(classOf[Array[Any]]), returnNullable = false)

val valueData =
Invoke(
MapObjects(
p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
returnNullable = false),
schemaFor(valueType).dataType),
"array",
ObjectType(classOf[Array[Any]]), returnNullable = false)

StaticInvoke(
ArrayBasedMapData.getClass,
ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
"toScalaMap",
keyData :: valueData :: Nil)
CollectObjectsToMap(
p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should set returnNullable to false here

schemaFor(keyType).dataType,
p => deserializerFor(valueType, Some(p), walkedTypePath),
Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
schemaFor(valueType).dataType,
mirror.runtimeClass(t.typeSymbol.asClass)
)

case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.Try

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
Expand Down Expand Up @@ -652,6 +653,299 @@ case class MapObjects private(
}
}

object CollectObjectsToMap {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* Construct an instance of CollectObjects case class.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CollectObjects -> CollectObjectsToMap

*
* @param keyFunction The function applied on the key collection elements.
* @param keyInputData An expression that when evaluated returns a key collection object.
* @param keyElementType The data type of key elements in the collection.
* @param valueFunction The function applied on the value collection elements.
* @param valueInputData An expression that when evaluated returns a value collection object.
* @param valueElementType The data type of value elements in the collection.
* @param collClass The type of the resulting collection.
*/
def apply(
keyFunction: Expression => Expression,
keyInputData: Expression,
keyElementType: DataType,
valueFunction: Expression => Expression,
valueInputData: Expression,
valueElementType: DataType,
collClass: Class[_]): CollectObjectsToMap = {
val id = curId.getAndIncrement()
val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this? the map key can not be null by definition

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. A key in MapData cannot be null. However, since the function takes two ArrayDatas as input, I figured that we shouldn't count on this requirement being necessarily fulfilled. As CollectObjectsToMap is a class separate from its usage in ScalaReflection, I tried to make it as generic and as similar to MapObjects as possible, so it can be used elsewhere without having to make sure additional preconditions are met.
It also produces a generic Map which has implementations that can support null keys. Right now, the only check that prevents this is here. If there is ever a need to support these kinds of Maps in the future, this should make the job easier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh I think this new expression should only be used to turn a catalyst map to external map, and we don't need to generalize it. We can even let it only accept a map type input, instead of 2 array inputs.

val keyLoopVar = LambdaVariable(keyLoopValue, keyLoopIsNull, keyElementType)
val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, valueElementType)
val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id"
val builderValue = s"CollectObjectsToMap_builderValue$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generate name for keyLoopVar and valueLoopVar here because they are used in the keyFunction and valueFunction. The tupleLoopVar and builderValue don't have this problem and we can generate them in class CollectObjectsToMap

CollectObjectsToMap(
keyLoopValue, keyLoopIsNull, keyElementType, keyFunction(keyLoopVar), keyInputData,
valueLoopValue, valueLoopIsNull, valueElementType, valueFunction(valueLoopVar),
valueInputData,
tupleLoopVar, collClass, builderValue)
}
}

/**
* An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's update the class doc to explicitly say that this expression is used to convert a catalyst map to external map.

* a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
* on the collection's companion object.
*
* @param keyLoopValue the name of the loop variable that is used when iterating over the key
* collection, and which is used as input for the `keyLambdaFunction`
* @param keyLoopIsNull the nullability of the loop variable that is used when iterating over
* the key collection, and which is used as input for the `keyLambdaFunction`
* @param keyLoopVarDataType the data type of the loop variable that is used when iterating over
* the key collection, and which is used as input for the
* `keyLambdaFunction`
* @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param keyInputData An expression that when evaluated returns a collection object.
* @param valueLoopValue the name of the loop variable that is used when iterating over the value
* collection, and which is used as input for the `valueLambdaFunction`
* @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
* the value collection, and which is used as input for the
* `valueLambdaFunction`
* @param valueLoopVarDataType the data type of the loop variable that is used when iterating over
* the value collection, and which is used as input for the
* `valueLambdaFunction`
* @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
* a lambda function to handle collection elements.
* @param valueInputData An expression that when evaluated returns a collection object.
* @param tupleLoopValue the name of the loop variable that holds the tuple to be added to the
* resulting map (used only for Scala Map)
* @param collClass The type of the resulting collection.
* @param builderValue The name of the builder variable used to construct the resulting collection.
*/
case class CollectObjectsToMap private(
keyLoopValue: String,
keyLoopIsNull: String,
keyLoopVarDataType: DataType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need key/value data types as parameters? We can easily get them from key/value input data expression

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modelled this class after the MapObjects class so that they could be used similarly. I noticed that since then a new UnresolvedMapObjects class was introduced which also doesn't require the element data type. Would this be something similar? And if so, shouldn't I rather introduce a new UnresolvedCollectObjectsToMap class instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnresolvedMapObject is used for dynamic type mapping of array element, but we don't need this for map element.

keyLambdaFunction: Expression,
keyInputData: Expression,
valueLoopValue: String,
valueLoopIsNull: String,
valueLoopVarDataType: DataType,
valueLambdaFunction: Expression,
valueInputData: Expression,
tupleLoopValue: String,
collClass: Class[_],
builderValue: String) extends Expression with NonSQLExpression {

override def nullable: Boolean = keyInputData.nullable

override def children: Seq[Expression] =
keyLambdaFunction :: keyInputData :: valueLambdaFunction :: valueInputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType = ObjectType(collClass)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val keyElementJavaType = ctx.javaType(keyLoopVarDataType)
ctx.addMutableState("boolean", keyLoopIsNull, "")
ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
val genKeyInputData = keyInputData.genCode(ctx)
val genKeyFunction = keyLambdaFunction.genCode(ctx)
val valueElementJavaType = ctx.javaType(valueLoopVarDataType)
ctx.addMutableState("boolean", valueLoopIsNull, "")
ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
val genValueInputData = valueInputData.genCode(ctx)
val genValueFunction = valueLambdaFunction.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
val loopIndex = ctx.freshName("loopIndex")

// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
// of input collection at runtime for this case.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this here. The key/value arrays come from MapData.getKeyArray, so there is no need to determine the type at runtime because it's always ArrayData

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned earlier, I tried to make this class as generic and similar to MapObjects as possible so it can be used elsewhere without certain preconditions being met. Granted that getting sequences here in the future is unlikely. Should I remove it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea let remove it, see #16986 (comment)

val keySeq = ctx.freshName("keySeq")
val keyArray = ctx.freshName("keyArray")
val valueSeq = ctx.freshName("valueSeq")
val valueArray = ctx.freshName("valueArray")
def determineCollectionType(inputData: Expression, genInputData: ExprCode,
elementJavaType: String, seq: String, array: String) =
inputData.dataType match {
case ObjectType(cls) if cls == classOf[Object] =>
val seqClass = classOf[Seq[_]].getName
s"""
$seqClass $seq = null;
$elementJavaType[] $array = null;
if (${genInputData.value}.getClass().isArray()) {
$array = ($elementJavaType[]) ${genInputData.value};
} else {
$seq = ($seqClass) ${genInputData.value};
}
"""
case _ => ""
}
val determineKeyCollectionType = determineCollectionType(
keyInputData, genKeyInputData, keyElementJavaType, keySeq, keyArray)
val determineValueCollectionType = determineCollectionType(
valueInputData, genValueInputData, valueElementJavaType, valueSeq, valueArray)

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
def inputDataType(inputData: Expression) = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
case _ => inputData.dataType
}
val keyInputDataType = inputDataType(keyInputData)
val valueInputDataType = inputDataType(valueInputData)

def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can simplify this method to only handle ArrayType

seq: String, array: String) =
inputDataType match {
case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
case ObjectType(cls) if cls.isArray =>
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
case ArrayType(et, _) =>
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
case ObjectType(cls) if cls == classOf[Object] =>
s"$seq == null ? $array.length : $seq.size()" ->
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}
val ((getKeyLength, getKeyLoopVar), (getValueLength, getValueLoopVar)) = (
lengthAndLoopVar(inputDataType(keyInputData), genKeyInputData, keySeq, keyArray),
lengthAndLoopVar(inputDataType(valueInputData), genValueInputData, valueSeq, valueArray)
)

// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
lambdaFunction.dataType match {
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
case _ => genFunction.value
}
val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

def loopNullCheck(genInputData: ExprCode, inputDataType: DataType,
loopIsNull: String, loopValue: String) =
inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
s"$loopIsNull = false"
case _ => s"$loopIsNull = $loopValue == null;"
}
val keyLoopNullCheck =
loopNullCheck(genKeyInputData, keyInputDataType, keyLoopIsNull, keyLoopValue)
val valueLoopNullCheck =
loopNullCheck(genValueInputData, valueInputDataType, valueLoopIsNull, valueLoopValue)

val constructBuilder = collClass match {
// Scala Map
case cls if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
val builderClass = classOf[Builder[_, _]].getName
s"""
$builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
$builderValue.sizeHint($dataLength);
"""
// Java Map, AbstractMap => HashMap
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR focus on scala map, we don't even have a test for java map. Let's remove these and add them back in the follow-up PR.

case cls if classOf[java.util.Map[_, _]] == cls ||
classOf[java.util.AbstractMap[_, _]] == cls =>
val builderClass = classOf[java.util.HashMap[_, _]].getName
s"$builderClass $builderValue = new $builderClass($dataLength);"
// Java SortedMap, NavigableMap => TreeMap
case cls if classOf[java.util.SortedMap[_, _]] == cls ||
classOf[java.util.NavigableMap[_, _]] == cls =>
val builderClass = classOf[java.util.TreeMap[_, _]].getName
s"$builderClass $builderValue = new $builderClass();"
// Java ConcurrentMap => ConcurrentHashMap
case cls if classOf[java.util.concurrent.ConcurrentMap[_, _]] == cls =>
val builderClass = classOf[java.util.concurrent.ConcurrentHashMap[_, _]].getName
s"$builderClass $builderValue = new $builderClass();"
// Java ConcurrentNavigableMap => ConcurrentSkipListMap
case cls if classOf[java.util.concurrent.ConcurrentNavigableMap[_, _]] == cls =>
val builderClass = classOf[java.util.concurrent.ConcurrentSkipListMap[_, _]].getName
s"$builderClass $builderValue = new $builderClass();"
// Java concrete Map implementation
case cls =>
val builderClass = classOf[java.util.Map[_, _]].getName
// Check for constructor with capacity specification
if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) {
s"$builderClass $builderValue = new ${cls.getName}($dataLength);"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use the customer map type as the type declaration, then we don't need the cast in https://github.com/apache/spark/pull/16986/files#diff-e436c96ea839dfe446837ab2a3531f93R901

} else {
s"$builderClass $builderValue = new ${cls.getName}();"
}
}

val (appendToBuilder, getBuilderResult) =
if (classOf[scala.collection.Map[_, _]].isAssignableFrom(collClass)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, let's focus on scala map for now.

val tupleClass = classOf[(_, _)].getName
s"""
$tupleClass $tupleLoopValue;

if (${genValueFunction.isNull}) {
$tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
} else {
$tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
}

$builderValue.$$plus$$eq($tupleLoopValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is ok, but it will be great if there is a way to avoid creating the tuple every time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, MapBuilder only accepts Tuples

""" -> s"${ev.value} = (${collClass.getName}) $builderValue.result();"
} else {
s"$builderValue.put($genKeyFunctionValue, $genValueFunctionValue);" ->
s"${ev.value} = (${collClass.getName}) $builderValue;"
}

val code = s"""
${genKeyInputData.code}
${genValueInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if ((${genKeyInputData.isNull} && !${genValueInputData.isNull}) ||
(!${genKeyInputData.isNull} && ${genValueInputData.isNull})) {
throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
}

if (!${genKeyInputData.isNull}) {
$determineKeyCollectionType
$determineValueCollectionType
if ($getKeyLength != $getValueLength) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need a keyLength and valueLength, just have a mapLength which can be calculated by MapData.numElements

throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
}
int $dataLength = $getKeyLength;
$constructBuilder

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
$keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
$valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
$keyLoopNullCheck
$valueLoopNullCheck
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also inline this. The principle is, we should inline these simple codes as many as possible, then when you look at this code block, it's more clear what's going on.


${genKeyFunction.code}
${genValueFunction.code}

if (${genKeyFunction.isNull}) {
throw new RuntimeException("Found null in map key!");
}

$appendToBuilder

$loopIndex += 1;
}

$getBuilderResult
}
"""
ev.copy(code = code, isNull = genKeyInputData.isNull)
}
}

object ExternalMapToCatalyst {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
}

test("serialize and deserialize arbitrary map types") {
val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
assert(mapSerializer.dataType.head.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val mapDeserializer = deserializerFor[Map[Int, Int]]
assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))

import scala.collection.immutable.HashMap
val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
assert(hashMapSerializer.dataType.head.dataType ==
MapType(IntegerType, IntegerType, valueContainsNull = false))
val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))

import scala.collection.mutable.{LinkedHashMap => LHMap}
val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
assert(linkedHashMapSerializer.dataType.head.dataType ==
MapType(LongType, StringType, valueContainsNull = true))
val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql

import scala.collection.Map
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

Expand Down Expand Up @@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 2.2.0 */
implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()

// Maps
/** @since 2.2.0 */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to be 2.3 now.

implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()

// Arrays

/** @since 1.6.1 */
Expand Down
Loading