Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.Modifier

import scala.collection.JavaConverters._
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
Expand Down Expand Up @@ -501,12 +502,22 @@ case class LambdaVariable(
value: String,
isNull: String,
dataType: DataType,
nullable: Boolean = true) extends LeafExpression
with Unevaluable with NonSQLExpression {
nullable: Boolean = true) extends LeafExpression with NonSQLExpression {

// Interpreted execution of `LambdaVariable` always gets the 0-index element from the input row.
override def eval(input: InternalRow): Any = {
assert(input.numFields == 1,
"The input row of interpreted LambdaVariable should have only 1 field.")
input.get(0, dataType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a change for this PR. Maybe we should use accessors here? This uses pattern matching under the hood and is slower than virtual function dispatch. Implementing this would also be useful for BoundReference, for example.

Copy link
Member Author

@viirya viirya Mar 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean something like this?

lazy val accessor:  InternalRow => Any = dataType match {
  case IntegerType => (inputRow) => inputRow.getInt(0)
  case LongType => (inputRow) => inputRow.getLong(0)
  ...
}

override def eval(input: InternalRow): Any = accessor(input)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's spin that off into a different ticket if we want to work on it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. After this is merged, I will create another PR for it.

}

override def genCode(ctx: CodegenContext): ExprCode = {
ExprCode(code = "", value = value, isNull = if (nullable) isNull else "false")
}

// This won't be called as `genCode` is overrided, just overriding it to make
// `LambdaVariable` non-abstract.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev
}

/**
Expand Down Expand Up @@ -599,8 +610,71 @@ case class MapObjects private(

override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
lazy private val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use the UserDefinedType super class here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I just noticed that this wasn't introduced by you, but please change it anyway)

case _ => inputData.dataType
}

private def executeFuncOnCollection(inputCollection: Seq[_]): Seq[_] = {
inputCollection.map { element =>
val row = InternalRow.fromSeq(Seq(element))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT reuse the row object.

lambdaFunction.eval(row)
}
}

override def eval(input: InternalRow): Any = {
val inputCollection = inputData.eval(input)

if (inputCollection == null) {
return inputCollection
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: It is slightly clearer to return null here.

}

val results = inputDataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be doing this during eval. Please move this into a function val.

case ObjectType(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
executeFuncOnCollection(inputCollection.asInstanceOf[Seq[_]])
case ObjectType(cls) if cls.isArray =>
executeFuncOnCollection(inputCollection.asInstanceOf[Array[_]].toSeq)
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
executeFuncOnCollection(inputCollection.asInstanceOf[java.util.List[_]].asScala)
case ObjectType(cls) if cls == classOf[Object] =>
if (inputCollection.getClass.isArray) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I am sorry for sounding like a broken record) But can we move this check out of the function closure?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry...

executeFuncOnCollection(inputCollection.asInstanceOf[Array[_]].toSeq)
} else {
executeFuncOnCollection(inputCollection.asInstanceOf[Seq[_]])
}
case ArrayType(et, _) =>
executeFuncOnCollection(inputCollection.asInstanceOf[ArrayData].array)
}

customCollectionCls match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We shouldn't be doing this during eval. Please move this into a function val.

case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
results.toSeq
case Some(cls) if classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala set
results.toSet
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
// Java list
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
cls == classOf[java.util.AbstractSequentialList[_]]) {
results.asJava
} else {
val builder = Try(cls.getConstructor(Integer.TYPE)).map { constructor =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you try to do the constructor lookup only once? The duplication that that will cause is ok.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if I understand correctly. Please check update again.

constructor.newInstance(results.length.asInstanceOf[Object])
}.getOrElse {
cls.getConstructor().newInstance()
}.asInstanceOf[java.util.List[Any]]

results.foreach(builder.add(_))
builder
}
case None =>
// array
new GenericArrayData(results.toArray)
}
}

override def dataType: DataType =
customCollectionCls.map(ObjectType.apply).getOrElse(
Expand Down Expand Up @@ -647,13 +721,6 @@ case class MapObjects private(
case _ => ""
}

// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
val inputDataType = inputData.dataType match {
case p: PythonUserDefinedType => p.sqlType
case _ => inputData.dataType
}

// `MapObjects` generates a while loop to traverse the elements of the input collection. We
// need to take care of Seq and List because they may have O(n) complexity for indexed accessing
// like `list.get(1)`. Here we use Iterator to traverse Seq and List.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

package org.apache.spark.sql.catalyst.expressions

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
import org.apache.spark.sql.types._


Expand Down Expand Up @@ -123,4 +125,55 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(encodeUsingSerializer, null, InternalRow.fromSeq(Seq(null)))
}
}

test("SPARK-23587: MapObjects should support interpreted execution") {
val customCollectionClasses = Seq(classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
classOf[java.util.Stack[Int]], null)
val function = (lambda: Expression) => Add(lambda, Literal(1))
val elementType = IntegerType
val expected = Seq(2, 3, 4)

val list = new java.util.ArrayList[Int]()
list.add(1)
list.add(2)
list.add(3)
val arrayData = new GenericArrayData(Array(1, 2, 3))
val vector = new java.util.Vector[Int]()
vector.add(1)
vector.add(2)
vector.add(3)
val stack = new java.util.Stack[Int]()
stack.add(1)
stack.add(2)
stack.add(3)

Seq(
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
(list, ObjectType(classOf[java.util.List[Int]])),
(vector, ObjectType(classOf[java.util.Vector[Int]])),
(arrayData, ArrayType(IntegerType))
).foreach { case (collection, inputType) =>
val inputObject = BoundReference(0, inputType, nullable = true)

customCollectionClasses.foreach { customCollectionCls =>
val optClass = Option(customCollectionCls)
val mapObj = MapObjects(function, inputObject, elementType, true, optClass)
val row = InternalRow.fromSeq(Seq(collection))
val result = mapObj.eval(row)

customCollectionCls match {
case null =>
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
assert(result.asInstanceOf[java.util.List[_]].asScala.toSeq == expected)
case s if classOf[Seq[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[Seq[_]].toSeq == expected)
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
}
}
}
}
}