Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
}
}

val array = Invoke(
MapObjects(mapFunction, getPath, dataType),
"array",
ObjectType(classOf[Array[Any]]))

val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
val cls = t.companion.decl(TermName("newBuilder")) match {
case NoSymbol => classOf[Seq[_]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked, and the builder for Seq[_] is a mutable.ListBuffer, so its build result will be an immutable.List. However, the original deserialized type of Seq[_] is a WrappedArray.

Can we keep the original expression when we can't find newBuilder here, i.e., return the WrappedArray?
Or could we use WrappedArray as the collClass instead of Seq?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought, this may not be a problem. Let me test it first.

case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
MapObjects(mapFunction, getPath, dataType, cls)

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag

Expand Down Expand Up @@ -429,24 +430,33 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
* @param collClass The class of the resulting collection
*/
def apply(
function: Expression => Expression,
inputData: Expression,
elementType: DataType): MapObjects = {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
elementType: DataType,
collClass: Class[_] = classOf[Array[_]]): MapObjects = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make this collClass parameter optional.

Copy link
Contributor

@cloud-fan cloud-fan Mar 24, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And rename it to customCollectionCls.

val id = curId.getAndIncrement()
val loopValue = s"MapObjects_loopValue$id"
val loopIsNull = s"MapObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
val builderValue = s"MapObjects_builderValue$id"
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
collClass, builderValue)
}
}

/**
* Applies the given expression to every element of a collection of items, returning the result
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
* is expressed using catalyst expressions.
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
* function is expressed using catalyst expressions.
*
* The type of the result is determined as follows:
* - ArrayType - when collClass is an array class
* - ObjectType(collClass) - when collClass is a collection class
*
* The following collection ObjectTypes are currently supported:
* The following collection ObjectTypes are currently supported on input:
* Seq, Array, ArrayData, java.util.List
*
* @param loopValue the name of the loop variable that used when iterate the collection, and used
Expand All @@ -458,13 +468,18 @@ object MapObjects {
* @param lambdaFunction A function that takes the `loopVar` as input and is used as the lambda
* function to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param collClass The class of the resulting collection
* @param builderValue The name of the builder variable used to construct the resulting collection
* (used only when returning ObjectType)
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {
inputData: Expression,
collClass: Class[_],
builderValue: String) extends Expression with NonSQLExpression {

override def nullable: Boolean = inputData.nullable

Expand All @@ -474,7 +489,8 @@ case class MapObjects private(
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType =
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
if (!collClass.isArray) ObjectType(collClass)
else ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
Expand Down Expand Up @@ -557,15 +573,32 @@ case class MapObjects private(
case _ => s"$loopIsNull = $loopValue == null;"
}

val (genInit, genAssign, genResult): (String, String => String, String) =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about naming these initCollection, addElement, and getResult?

if (collClass.isArray) {
// array
(s"""$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;""",
genValue => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
} else {
// collection
val collObjectName = s"${collClass.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"

(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);""",
genValue => s"$builderValue.$$plus$$eq($genValue);",
s"(${collClass.getName}) $builderValue.result();")
}

val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if (!${genInputData.isNull}) {
$determineCollectionType
$convertedType[] $convertedArray = null;
int $dataLength = $getLength;
$convertedArray = $arrayConstructor;
$genInit

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
Expand All @@ -574,15 +607,15 @@ case class MapObjects private(

${genFunction.code}
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
${genAssign("null")}
} else {
$convertedArray[$loopIndex] = $genFunctionValue;
${genAssign(genFunctionValue)}
}

$loopIndex += 1;
}

${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
${ev.value} = $genResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
Expand Down