Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
}
}

val array = Invoke(
MapObjects(mapFunction, getPath, dataType),
"array",
ObjectType(classOf[Array[Any]]))

val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dealias is not available in scala 2.10, @michalsenkyr can you come up with a workaround? thanks!

case NoSymbol => classOf[Seq[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
MapObjects(mapFunction, getPath, dataType, Some(cls))

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag

Expand Down Expand Up @@ -429,24 +430,34 @@ object MapObjects {
* @param function The function applied on the collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param elementType The data type of elements in the collection.
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
* or None (returning ArrayType)
*/
def apply(
function: Expression => Expression,
inputData: Expression,
elementType: DataType): MapObjects = {
val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
elementType: DataType,
customCollectionCls: Option[Class[_]] = None): MapObjects = {
val id = curId.getAndIncrement()
val loopValue = s"MapObjects_loopValue$id"
val loopIsNull = s"MapObjects_loopIsNull$id"
val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
val builderValue = s"MapObjects_builderValue$id"
MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
customCollectionCls, builderValue)
}
}

/**
* Applies the given expression to every element of a collection of items, returning the result
* as an ArrayType. This is similar to a typical map operation, but where the lambda function
* is expressed using catalyst expressions.
* as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
* function is expressed using catalyst expressions.
*
* The type of the result is determined as follows:
* - ArrayType - when customCollectionCls is None
* - ObjectType(collection) - when customCollectionCls contains a collection class
*
* The following collection ObjectTypes are currently supported:
* The following collection ObjectTypes are currently supported on input:
* Seq, Array, ArrayData, java.util.List
*
* @param loopValue the name of the loop variable that used when iterate the collection, and used
Expand All @@ -458,13 +469,19 @@ object MapObjects {
* @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
* to handle collection elements.
* @param inputData An expression that when evaluated returns a collection object.
* @param customCollectionCls Class of the resulting collection (returning ObjectType)
* or None (returning ArrayType)
* @param builderValue The name of the builder variable used to construct the resulting collection
* (used only when returning ObjectType)
*/
case class MapObjects private(
loopValue: String,
loopIsNull: String,
loopVarDataType: DataType,
lambdaFunction: Expression,
inputData: Expression) extends Expression with NonSQLExpression {
inputData: Expression,
customCollectionCls: Option[Class[_]],
builderValue: String) extends Expression with NonSQLExpression {

override def nullable: Boolean = inputData.nullable

Expand All @@ -474,7 +491,8 @@ case class MapObjects private(
throw new UnsupportedOperationException("Only code-generated evaluation is supported")

override def dataType: DataType =
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
customCollectionCls.map(ObjectType.apply).getOrElse(
ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val elementJavaType = ctx.javaType(loopVarDataType)
Expand Down Expand Up @@ -557,15 +575,33 @@ case class MapObjects private(
case _ => s"$loopIsNull = $loopValue == null;"
}

val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
case Some(cls) =>
// collection
val collObjectName = s"${cls.getName}$$.MODULE$$"
val getBuilderVar = s"$collObjectName.newBuilder()"

(s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
$builderValue.sizeHint($dataLength);""",
genValue => s"$builderValue.$$plus$$eq($genValue);",
s"(${cls.getName}) $builderValue.result();")
case None =>
// array
(s"""$convertedType[] $convertedArray = null;
$convertedArray = $arrayConstructor;""",
genValue => s"$convertedArray[$loopIndex] = $genValue;",
s"new ${classOf[GenericArrayData].getName}($convertedArray);")
}

val code = s"""
${genInputData.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if (!${genInputData.isNull}) {
$determineCollectionType
$convertedType[] $convertedArray = null;
int $dataLength = $getLength;
$convertedArray = $arrayConstructor;
$initCollection

int $loopIndex = 0;
while ($loopIndex < $dataLength) {
Expand All @@ -574,15 +610,15 @@ case class MapObjects private(

${genFunction.code}
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
${addElement("null")}
} else {
$convertedArray[$loopIndex] = $genFunctionValue;
${addElement(genFunctionValue)}
}

$loopIndex += 1;
}

${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
${ev.value} = $getResult
}
"""
ev.copy(code = code, isNull = genInputData.isNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
Expand Down