From fbf92c586d19ee581ced8802dd1b9f6f6f73e220 Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sun, 16 Apr 2017 17:54:40 +0200 Subject: [PATCH 1/2] Add arbitrary Java List serialization/deserialization support --- .../spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++++------- .../expressions/objects/objects.scala | 19 +++++++++++++++++-- .../sql/catalyst/ScalaReflectionSuite.scala | 10 ++++++++++ .../org/apache/spark/sql/SQLImplicits.scala | 4 ++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 87130532c89b..ac1f27160117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,7 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } - case t if t <:< localTypeOf[Seq[_]] => + case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -324,10 +324,14 @@ object ScalaReflection extends ScalaReflection { } } - val companion = t.normalize.typeSymbol.companionSymbol.typeSignature - val cls = companion.declaration(newTermName("newBuilder")) match { - case NoSymbol => classOf[Seq[_]] - case _ => mirror.runtimeClass(t.typeSymbol.asClass) + val cls = if (t <:< localTypeOf[java.util.List[_]]) { + mirror.runtimeClass(t.typeSymbol.asClass) + } else { + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + companion.declaration(newTermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) + } } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) @@ -494,7 +498,7 @@ object ScalaReflection extends ScalaReflection { // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] => + case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) @@ -712,7 +716,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Seq[_]] => + case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 1a202ecf745c..ceba3f0e778c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials 
import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -597,8 +598,8 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) => - // collection + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( @@ -609,6 +610,20 @@ case class MapObjects private( genValue => s"$builder.$$plus$$eq($genValue);", s"(${cls.getName}) $builder.result();" ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + genValue => s"$builder.add($genValue);", + s"$builder;" + ) case None => // array ( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 70ad064f93eb..ceceffe85e41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -314,6 +314,16 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]])) } + test("serialize and deserialize arbitrary java list types") { + import java.util.ArrayList + val arrayListSerializer = serializerFor[ArrayList[Int]](BoundReference( + 0, ObjectType(classOf[ArrayList[Int]]), nullable = false)) + assert(arrayListSerializer.dataType.head.dataType == + ArrayType(IntegerType, containsNull = false)) + val arrayListDeserializer = deserializerFor[ArrayList[Int]] + assert(arrayListDeserializer.dataType == ObjectType(classOf[ArrayList[_]])) + } + private val dataTypeForComplexData = dataTypeFor[ComplexData] private val typeOfComplexData = typeOf[ComplexData] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 17671ea8685b..431d87f89539 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -166,6 +166,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.2.0 */ implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + /** @since 2.2.0 */ + implicit def newJavaListEncoder[T <: java.util.List[_] : TypeTag]: Encoder[T] = + ExpressionEncoder() + // Arrays /** @since 1.6.1 */ From 881e6363ca6a8545e45f3e023f32762afd60943e Mon Sep 17 00:00:00 2001 From: Michal Senkyr Date: Sat, 10 Jun 2017 20:02:49 +0200 Subject: [PATCH 2/2] Add specific Java List support to JavaTypeInference Remove specific Java List support from ScalaReflection Remove implicit encoder for Java Lists Add relevant tests to JavaDatasetSuite Remove tests from ScalaReflectionSuite and DatasetPrimitiveSuite --- .../sql/catalyst/JavaTypeInference.scala | 15 ++--- 
.../spark/sql/catalyst/ScalaReflection.scala | 18 +++--- .../sql/catalyst/ScalaReflectionSuite.scala | 10 --- .../org/apache/spark/sql/SQLImplicits.scala | 4 -- .../apache/spark/sql/JavaDatasetSuite.java | 61 +++++++++++++++++++ 5 files changed, 73 insertions(+), 35 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319ec3..7683ee7074e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -267,16 +267,11 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val array = - Invoke( - MapObjects( - p => deserializerFor(et, Some(p)), - getPath, - inferDataType(et)._1), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1, + customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ac1f27160117..87130532c89b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -307,7 +307,7 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } - case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => + case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -324,14 +324,10 @@ object ScalaReflection extends ScalaReflection { } } - val cls = if (t <:< localTypeOf[java.util.List[_]]) { - mirror.runtimeClass(t.typeSymbol.asClass) - } else { - val companion = t.normalize.typeSymbol.companionSymbol.typeSignature - companion.declaration(newTermName("newBuilder")) match { - case NoSymbol => classOf[Seq[_]] - case _ => mirror.runtimeClass(t.typeSymbol.asClass) - } + val companion = t.normalize.typeSymbol.companionSymbol.typeSignature + val cls = companion.declaration(newTermName("newBuilder")) match { + case NoSymbol => classOf[Seq[_]] + case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) @@ -498,7 +494,7 @@ object ScalaReflection extends ScalaReflection { // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => + case t if t <:< localTypeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) @@ -716,7 +712,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Seq[_]] || t <:< localTypeOf[java.util.List[_]] => + case t if t 
<:< localTypeOf[Seq[_]] =>
       val TypeRef(_, _, Seq(elementType)) = t
       val Schema(dataType, nullable) = schemaFor(elementType)
       Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index ceceffe85e41..70ad064f93eb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -314,16 +314,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
   }
 
-  test("serialize and deserialize arbitrary java list types") {
-    import java.util.ArrayList
-    val arrayListSerializer = serializerFor[ArrayList[Int]](BoundReference(
-      0, ObjectType(classOf[ArrayList[Int]]), nullable = false))
-    assert(arrayListSerializer.dataType.head.dataType ==
-      ArrayType(IntegerType, containsNull = false))
-    val arrayListDeserializer = deserializerFor[ArrayList[Int]]
-    assert(arrayListDeserializer.dataType == ObjectType(classOf[ArrayList[_]]))
-  }
-
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
   private val typeOfComplexData = typeOf[ComplexData]
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index 431d87f89539..17671ea8685b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -166,10 +166,6 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 2.2.0 */
   implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
 
-  /** @since 2.2.0 */
-  implicit def newJavaListEncoder[T <: java.util.List[_] : TypeTag]: Encoder[T] =
-    ExpressionEncoder()
-
   // Arrays
 
   /** @since 1.6.1 */
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 3ba37addfc8b..4ca3b6406a32 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1399,4 +1399,65 @@ public void testSerializeNull() {
       ds1.map((MapFunction) b -> b, encoder);
     Assert.assertEquals(beans, ds2.collectAsList());
   }
+
+  @Test
+  public void testSpecificLists() {
+    SpecificListsBean bean = new SpecificListsBean();
+    ArrayList<Integer> arrayList = new ArrayList<>();
+    arrayList.add(1);
+    bean.setArrayList(arrayList);
+    LinkedList<Integer> linkedList = new LinkedList<>();
+    linkedList.add(1);
+    bean.setLinkedList(linkedList);
+    bean.setList(Collections.singletonList(1));
+    List<SpecificListsBean> beans = Collections.singletonList(bean);
+    Dataset<SpecificListsBean> dataset =
+      spark.createDataset(beans, Encoders.bean(SpecificListsBean.class));
+    Assert.assertEquals(beans, dataset.collectAsList());
+  }
+
+  public static class SpecificListsBean implements Serializable {
+    private ArrayList<Integer> arrayList;
+    private LinkedList<Integer> linkedList;
+    private List<Integer> list;
+
+    public ArrayList<Integer> getArrayList() {
+      return arrayList;
+    }
+
+    public void setArrayList(ArrayList<Integer> arrayList) {
+      this.arrayList = arrayList;
+    }
+
+    public LinkedList<Integer> getLinkedList() {
+      return linkedList;
+    }
+
+    public void setLinkedList(LinkedList<Integer> linkedList) {
+      this.linkedList = linkedList;
+    }
+
+    public List<Integer> getList() {
+      return list;
+    }
+
+    public void setList(List<Integer> list) {
+      this.list = list;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SpecificListsBean that = (SpecificListsBean) o;
+      return Objects.equal(arrayList, that.arrayList) &&
+        Objects.equal(linkedList, that.linkedList) &&
+        Objects.equal(list, that.list);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(arrayList, linkedList, list);
+    }
+  }
+}
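
A note on the collection-instantiation strategy in the MapObjects codegen added by patch 1 (and reused by patch 2 via customCollectionCls): when the requested class is the bare java.util.List interface, java.util.AbstractList, or java.util.AbstractSequentialList, the generated code falls back to a java.util.ArrayList sized to the input; for any other concrete list class it reflectively prefers an int-argument capacity constructor and otherwise uses the no-arg constructor. Below is a minimal sketch of that same decision written as ordinary Java rather than generated source; the helper name newListFor is hypothetical and not part of the patch.

    import java.util.AbstractList;
    import java.util.AbstractSequentialList;
    import java.util.ArrayList;
    import java.util.List;

    class ListInstantiationSketch {
      // Hypothetical mirror of the customCollectionCls branch in objects.scala:
      // abstract list types fall back to ArrayList; concrete types are built
      // reflectively, preferring a capacity constructor when one exists.
      @SuppressWarnings("unchecked")
      static List<Object> newListFor(Class<?> cls, int capacity)
          throws ReflectiveOperationException {
        if (cls == List.class || cls == AbstractList.class
            || cls == AbstractSequentialList.class) {
          return new ArrayList<>(capacity);  // no concrete type requested
        }
        try {
          // e.g. ArrayList(int initialCapacity)
          return (List<Object>) cls.getConstructor(int.class).newInstance(capacity);
        } catch (NoSuchMethodException e) {
          // e.g. LinkedList(), which has no capacity constructor
          return (List<Object>) cls.getConstructor().newInstance();
        }
      }
    }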
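The JavaTypeInference change in patch 2 is what makes testSpecificLists pass. Before the change, the deserializer wrapped a freshly built array with java.util.Arrays.asList, so every list-typed bean property was populated with the fixed-size java.util.Arrays$ArrayList regardless of the declared type; a property declared with a concrete class such as LinkedList could not be populated, since Arrays$ArrayList is not a LinkedList. Passing the declared class to MapObjects as customCollectionCls lets the generated code instantiate that class directly. A hypothetical bean illustrating the previously failing case (names are illustrative, not from the patch):

    import java.io.Serializable;
    import java.util.LinkedList;

    public class QueueBean implements Serializable {
      // A concrete list type, not just java.util.List; before this series the
      // deserializer could only produce Arrays$ArrayList for list properties.
      private LinkedList<String> queue;

      public LinkedList<String> getQueue() { return queue; }
      public void setQueue(LinkedList<String> queue) { this.queue = queue; }
    }

    // With the series applied, Encoders.bean(QueueBean.class) round-trips
    // QueueBean through a Dataset, deserializing straight into a LinkedList.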