Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,54 +307,8 @@ object ScalaReflection extends ScalaReflection {
}
}

val array = Invoke(
MapObjects(mapFunction, getPath, dataType),
"array",
ObjectType(classOf[Array[Any]]))

val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
CollectObjects(mapFunction, getPath, dataType, cls)

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag

Expand Down Expand Up @@ -589,6 +590,171 @@ case class MapObjects private(
}
}

object CollectObjects {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* Construct an instance of CollectObjects case class.
*
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
* @param collClass The type of the resulting collection.
*/
def apply(
function: Expression => Expression,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4-space indentation, please.

inputData: Expression,
elementType: DataType,
collClass: Class[_]): CollectObjects = {
// Draw one id per constructed expression. The distinct prefixes already keep the three
// generated names apart, so sharing the suffix makes the names that belong to the same
// expression easy to correlate in generated code.
val id = curId.getAndIncrement()
val loopValue = s"CollectObjects_loopValue$id"
val loopIsNull = s"CollectObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
val builderValue = s"CollectObjects_builderValue$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all these 3 variable names can use the same curId

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Altered vals in MapObjects to share the same curId instead

CollectObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
collClass, builderValue)
}
}

/**
* An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
* a Scala collection constructed using the associated builder, obtained by calling `newBuilder`
* on the collection's companion object.
*
* @param loopValue the name of the loop variable that is used when iterating over the
* collection, and that serves as input for the `lambdaFunction`
* @param loopIsNull the nullity of the loop variable that is used when iterating over the
* collection, and that serves as input for the `lambdaFunction`
* @param loopVarDataType the data type of the loop variable that is used when iterating over
* the collection, and that serves as input for the `lambdaFunction`
* @param lambdaFunction A function that takes the `loopVar` as input and is used as the lambda
* function to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param collClass The type of the resulting collection.
* @param builderValue The name of the builder variable used to construct the resulting collection.
*/
case class CollectObjects private(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems too heavy, can we improve MapObjects to add the builder concept?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we actually can. I merged CollectObjects into MapObjects in my next commit.

loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression,
collClass: Class[_],
builderValue: String) extends Expression with NonSQLExpression {

/** The result is nullable exactly when the input collection expression is nullable. */
override def nullable: Boolean = inputData.nullable

/** Child expressions: the per-element function first, then the input collection. */
override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

/** Interpreted evaluation is not implemented; calling this always throws. */
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

/** An `ObjectType` wrapping the requested collection class. */
override def dataType: DataType = ObjectType(collClass)

/**
 * Generates Java code that loops over the input collection, evaluates `lambdaFunction` for each
 * element, appends every result to a builder obtained from the companion object of `collClass`
 * and finally casts the built collection to `collClass`.
 */
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
// Java-side reference to the singleton instance of the collection's companion object. `$$` is
// an escaped literal `$`, so a class named `a.B` yields the text `a.B$.MODULE$`.
val collObjectName = s"${collClass.getName}$$.MODULE$$"
// Java expression that calls the companion's no-argument `newBuilder` method.
val getBuilderVar = s"$collObjectName.newBuilder()"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does each collection type in Scala have newBuilder implemented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it does. As TraversableLike documentation specifies:

Collection classes mixing in this trait ... also need to provide a method newBuilder which creates a builder for collections of the same kind.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But as far as I know, Range doesn't support that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to deal with this kind of case in #16546 recently. As you can see, I also need to check whether a subclass supports canBuildFrom, because Range doesn't support it either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I didn't realize Range breaks that rule.
Furthermore, it would be practically impossible to deserialize back into Range. So I guess the best way to do this will be to use a general Seq builder if the collection doesn't provide its own.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the Seq builder fallback. However, there is presently no collection that Spark supports that doesn't provide a builder. You can try it out on your branch with Range.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I see. Although we can't deserialize back to Range, there is no problem serializing it into the internal format in Spark SQL, so we can still convert the dataset to a DataFrame. With RowEncoder, we can deserialize back to Row. That is what I do in #16546.

I will check whether the Seq builder fallback works for that PR. Thanks.

Copy link
Contributor Author

@michalsenkyr michalsenkyr Jan 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I would have to modify serializerFor as you did here in order to extract the required element type.

// Java type of one input element. The loop variable and its null flag are registered as mutable
// state under the names given to this expression (presumably the names the LambdaVariable inside
// `lambdaFunction` refers to -- confirm against LambdaVariable's code generation).
val elementJavaType = ctx.javaType(loopVarDataType)
ctx.addMutableState("boolean", loopIsNull, "")
ctx.addMutableState(elementJavaType, loopValue, "")
val genInputData = inputData.genCode(ctx)
val genFunction = lambdaFunction.genCode(ctx)
val dataLength = ctx.freshName("dataLength")
val loopIndex = ctx.freshName("loopIndex")

// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
// of input collection at runtime for this case.
val seq = ctx.freshName("seq")
val array = ctx.freshName("array")
val determineCollectionType = inputData.dataType match {
  case ObjectType(cls) if cls == classOf[Object] =>
    val seqClass = classOf[Seq[_]].getName
    // Exactly one of `seq` / `array` is assigned at runtime; the other stays null.
    s"""
$seqClass $seq = null;
$elementJavaType[] $array = null;
if (${genInputData.value}.getClass().isArray()) {
$array = ($elementJavaType[]) ${genInputData.value};
} else {
$seq = ($seqClass) ${genInputData.value};
}
"""
  case _ => ""
}

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply CollectObjects on it, we have to use it.
val inputDataType = inputData.dataType match {
  case p: PythonUserDefinedType => p.sqlType
  case _ => inputData.dataType
}

// Java expressions for the number of input elements and for the element at `loopIndex`.
// There is no default case: any other input type fails with a MatchError.
val (getLength, getLoopVar) = inputDataType match {
  case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
    s"${genInputData.value}.size()" -> s"${genInputData.value}.apply($loopIndex)"
  case ObjectType(cls) if cls.isArray =>
    s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
  case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
    s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
  case ArrayType(et, _) =>
    s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
  case ObjectType(cls) if cls == classOf[Object] =>
    // Uses the `seq` / `array` locals declared by `determineCollectionType`.
    s"$seq == null ? $array.length : $seq.size()" ->
      s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}

// Make a copy of the data if it's unsafe-backed
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
  s"$value instanceof ${clazz.getSimpleName}? ${value}.copy() : $value"
val genFunctionValue = lambdaFunction.dataType match {
  case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
  case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
  case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
  case _ => genFunction.value
}

// Java statement that sets the loop variable's null flag. Every branch must be a complete,
// semicolon-terminated statement because it is spliced directly into the loop body below.
val loopNullCheck = inputDataType match {
  case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
  // The element of primitive array will never be null.
  case ObjectType(cls) if cls.isArray && cls.getComponentType.isPrimitive =>
    s"$loopIsNull = false;"
  case _ => s"$loopIsNull = $loopValue == null;"
}

// Generated Java: when the input is non-null, size-hint a fresh builder, append the mapped value
// (or null) for every element, then cast the built collection to the requested class.
// `$$plus$$eq` renders as `$plus$eq`, the name used here to call the builder's `+=` from Java.
val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if (!${genInputData.isNull}) {
$determineCollectionType
int $dataLength = $getLength;
${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
$loopValue = ($elementJavaType) ($getLoopVar);
$loopNullCheck

${genFunction.code}
if (${genFunction.isNull}) {
$builderValue.$$plus$$eq(null);
} else {
$builderValue.$$plus$$eq($genFunctionValue);
}

$loopIndex += 1;
}

${ev.value} = (${collClass.getName}) $builderValue.result();
}
"""
// The result is null exactly when the input collection is null.
ev.copy(code = code, isNull = genInputData.isNull)
}
}

object ExternalMapToCatalyst {
private val curId = new java.util.concurrent.atomic.AtomicInteger()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
Expand Down