From c634aa421a7f09eafbebf179e9ba056f498d0516 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Nov 2015 15:14:56 +0800 Subject: [PATCH 1/6] support java bean encoder --- .../scala/org/apache/spark/sql/Encoder.scala | 2 + .../sql/catalyst/JavaTypeInference.scala | 309 ++++++++++++++++-- .../spark/sql/catalyst/ScalaReflection.scala | 3 +- .../catalyst/encoders/ExpressionEncoder.scala | 21 +- .../sql/catalyst/expressions/objects.scala | 39 ++- .../spark/sql/catalyst/trees/TreeNode.scala | 27 +- .../sql/catalyst/util/ArrayBasedMapData.scala | 5 + .../apache/spark/sql/JavaDatasetSuite.java | 128 ++++++++ 8 files changed, 508 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 03aa25eda807..9bb755fd34ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,6 +97,8 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() + def bean[T](beanCls: Class[T]): Encoder[T] = ExpressionEncoder(beanCls) + /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7d4cfbe6faec..01c185b6a952 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,14 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.Introspector +import java.beans.{PropertyDescriptor, Introspector} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap} +import java.util.{Iterator => JIterator, Map => JMap, List => JList} import scala.language.existentials import com.google.common.reflect.TypeToken + import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.unsafe.types.UTF8String + /** * Type-inference utilities for POJOs and Java collections. @@ -33,13 +39,14 @@ object JavaTypeInference { private val iterableType = TypeToken.of(classOf[JIterable[_]]) private val mapType = TypeToken.of(classOf[JMap[_, _]]) + private val listType = TypeToken.of(classOf[JList[_]]) private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType /** - * Infers the corresponding SQL data type of a JavaClean class. + * Infers the corresponding SQL data type of a JavaBean class. 
* @param beanClass Java type * @return (SQL data type, nullable) */ @@ -58,6 +65,8 @@ object JavaTypeInference { (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) case c: Class[_] if c == classOf[java.lang.String] => (StringType, true) + case c: Class[_] if c == classOf[Array[Byte]] => (BinaryType, true) + case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false) case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false) case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false) @@ -87,31 +96,293 @@ object JavaTypeInference { (ArrayType(dataType, nullable), true) case _ if mapType.isAssignableFrom(typeToken) => - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSupertype.resolveType(keySetReturnType)) - val valueType = elementType(mapSupertype.resolveType(valuesReturnType)) + val (keyType, valueType) = mapKeyValueType(typeToken) val (keyDataType, _) = inferDataType(keyType) val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) - case _ => - val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) - val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") - val fields = properties.map { property => - val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType) - new StructField(property.getName, dataType, nullable) + case other => + val properties = getJavaBeanProperties(other) + if (properties.length > 0) { + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) + } + (new StructType(fields), true) + } else { + throw new UnsupportedOperationException(s"Cannot infer data type for 
${other.getName}") } - (new StructType(fields), true) } } + private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = { + val beanInfo = Introspector.getBeanInfo(beanClass) + beanInfo.getPropertyDescriptors + .filter(p => p.getReadMethod != null && p.getWriteMethod != null) + } + private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSupertype.resolveType(iteratorReturnType) - val itemType = iteratorType.resolveType(nextReturnType) - itemType + val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) + val iteratorType = iterableSuperType.resolveType(iteratorReturnType) + iteratorType.resolveType(nextReturnType) + } + + private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { + val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] + val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) + val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) + val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) + keyType -> valueType + } + + private def inferExternalType(cls: Class[_]): DataType = cls match { + case c if c == java.lang.Boolean.TYPE => BooleanType + case c if c == java.lang.Byte.TYPE => ByteType + case c if c == java.lang.Short.TYPE => ShortType + case c if c == java.lang.Integer.TYPE => IntegerType + case c if c == java.lang.Long.TYPE => LongType + case c if c == java.lang.Float.TYPE => FloatType + case c if c == java.lang.Double.TYPE => DoubleType + case c if c == classOf[Array[Byte]] => BinaryType + case _ => ObjectType(cls) + } + + def constructorFor(beanClass: Class[_]): Expression = { + constructorFor(TypeToken.of(beanClass), None) + } + + private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + /** Returns the 
current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true)) + + typeToken.getRawType match { + case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath + + case c if c == classOf[java.lang.Short] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Integer] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Long] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Double] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Byte] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Float] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + case c if c == classOf[java.lang.Boolean] => + NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(c), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case c if c == classOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case c if c == classOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case c if c.isArray => + val elementType = c.getComponentType + val primitiveMethod = elementType match { + case c if c == java.lang.Boolean.TYPE => Some("toBooleanArray") + 
case c if c == java.lang.Byte.TYPE => Some("toByteArray") + case c if c == java.lang.Short.TYPE => Some("toShortArray") + case c if c == java.lang.Integer.TYPE => Some("toIntArray") + case c if c == java.lang.Long.TYPE => Some("toLongArray") + case c if c == java.lang.Float.TYPE => Some("toFloatArray") + case c if c == java.lang.Double.TYPE => Some("toDoubleArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, ObjectType(c)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(typeToken.getComponentType, Some(p)), + getPath, + inferDataType(elementType)._1), + "array", + ObjectType(c)) + } + + case c if listType.isAssignableFrom(typeToken) => + val et = elementType(typeToken) + val array = + Invoke( + MapObjects( + p => constructorFor(et, Some(p)), + getPath, + inferDataType(et)._1), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + + case _ if mapType.isAssignableFrom(typeToken) => + val (keyType, valueType) = mapKeyValueType(typeToken) + val keyDataType = inferDataType(keyType)._1 + val valueDataType = inferDataType(valueType)._1 + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(keyDataType)), + keyDataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(valueDataType)), + valueDataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[JMap[_, _]]), + "toJavaMap", + keyData :: valueData :: Nil) + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + val setters = properties.map { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + p.getWriteMethod.getName -> constructorFor(fieldType, 
Some(addToPath(fieldName))) + }.toMap + + val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val result = InitializeJavaBean(newInstance, setters) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(other)), + result + ) + } else { + result + } + } + } + + def extractorsFor(inputObject: Expression, beanClass: Class[_]): CreateNamedStruct = { + extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + } + + private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + + def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { + val externalType = inferExternalType(elementType.getRawType) + val (dataType, nullable) = inferDataType(elementType) + if (ScalaReflection.isNativeType(dataType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dataType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalType) + } + } + + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + typeToken.getRawType match { + case c if c == classOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case c if c == classOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case c if c == classOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case c if c == classOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case c if c == classOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + case c if c == classOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case c if c == classOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + 
case c if c == classOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case c if c == classOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case c if c == classOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case c if c == classOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + + case _ if typeToken.isArray => + toCatalystArray(inputObject, typeToken.getComponentType) + + case _ if listType.isAssignableFrom(typeToken) => + toCatalystArray(inputObject, elementType(typeToken)) + + case _ if mapType.isAssignableFrom(typeToken) => + throw new UnsupportedOperationException("map type is not supported currently") + + case other => + val properties = getJavaBeanProperties(other) + assert(properties.length > 0) + + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d133ad3f0d89..66719667ab6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, ArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0c10a56c555f..6ef9159f7d82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,8 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} import org.apache.spark.sql.types.{StructField, ObjectType, StructType} /** @@ -67,6 +66,22 @@ object ExpressionEncoder { ClassTag[T](cls)) } + def apply[T](beanClass: Class[T]): ExpressionEncoder[T] = { + val schema = JavaTypeInference.inferDataType(beanClass)._1 + assert(schema.isInstanceOf[StructType]) + + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) + val toRowExpression = JavaTypeInference.extractorsFor(inputObject, beanClass) + val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + + new ExpressionEncoder[T]( + schema.asInstanceOf[StructType], + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](beanClass)) + } + /** * Given a set of N encoders, constructs a new encoder that produce objects as items in an * N-tuple. 
Note that these encoders should be unresolved so that information about @@ -215,7 +230,7 @@ case class ExpressionEncoder[T]( */ def assertUnresolved(): Unit = { (fromRowExpression +: toRowExpressions).foreach(_.foreach { - case a: AttributeReference => + case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 62d09f0f5510..501781cf7bea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -346,7 +346,8 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * as an ArrayType. This is similar to a typical map operation, but where the lambda function * is expressed using catalyst expressions. * - * The following collection ObjectTypes are currently supported: Seq, Array, ArrayData + * The following collection ObjectTypes are currently supported: + * Seq, Array, ArrayData, java.util.List * * @param function A function that returns an expression, given an attribute that can be used * to access the current value. 
This is does as a lambda function so that @@ -386,6 +387,8 @@ case class MapObjects( (".size()", (i: String) => s".apply($i)", false) case ObjectType(cls) if cls.isArray => (".length", (i: String) => s"[$i]", false) + case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + (".size()", (i: String) => s".get($i)", false) case ArrayType(t, _) => val (sqlType, primitiveElement) = t match { case m: MapType => (m, false) @@ -596,3 +599,37 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override def dataType: DataType = ObjectType(tag.runtimeClass) } + +case class InitializeJavaBean(n: NewInstance, setters: Map[String, Expression]) + extends Expression { + + override def nullable: Boolean = false + override def children: Seq[Expression] = n +: setters.values.toSeq + override def dataType: DataType = n.dataType + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val instance = n.gen(ctx) + + val initialize = setters.map { + case (setterMethod, fieldValue) => + val fieldGen = fieldValue.gen(ctx) + s""" + ${fieldGen.code} + ${instance.value}.$setterMethod(${fieldGen.value}); + """ + } + + ev.isNull = instance.isNull + ev.value = instance.value + + s""" + ${instance.code} + if (!${instance.isNull}) { + ${initialize.mkString("\n")} + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 35f087baccde..0650611c3916 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.trees +import scala.collection.Map + import 
org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.types.{StructType, DataType} @@ -191,6 +193,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case nonChild: AnyRef => nonChild case null => null } + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + case nonChild: AnyRef => nonChild + case null => null + }.view.force case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -262,7 +277,17 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } else { Some(arg) } - case m: Map[_, _] => m + case m: Map[_, _] => m.mapValues { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = nextOperation(arg.asInstanceOf[BaseType], rule) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case other => other + }.view.force case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 70b028d2b3f7..d85b72ed83de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -70,4 +70,9 @@ object ArrayBasedMapData { def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toJavaMap(keys: Array[Any], values: Array[Any]): java.util.Map[Any, Any] = { + import scala.collection.JavaConverters._ + keys.zip(values).toMap.asJava + } } 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 67a3190cb7d4..438fa4decce0 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -37,6 +37,7 @@ import org.apache.spark.sql.GroupedDataset; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; +import org.apache.spark.sql.catalyst.encoders.OuterScopes; import static org.apache.spark.sql.functions.*; @@ -506,4 +507,131 @@ public void testJavaEncoderErrorMessageForPrivateClass() { public void testKryoEncoderErrorMessageForPrivateClass() { Encoders.kryo(PrivateClassTest.class); } + + public class SimpleJavaBean implements Serializable { + private boolean a; + private int b; + private byte[] c; + private String[] d; + private List e; + + public boolean isA() { + return a; + } + + public void setA(boolean a) { + this.a = a; + } + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public byte[] getC() { + return c; + } + + public void setC(byte[] c) { + this.c = c; + } + + public String[] getD() { + return d; + } + + public void setD(String[] d) { + this.d = d; + } + + public List getE() { + return e; + } + + public void setE(List e) { + this.e = e; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean that = (SimpleJavaBean) o; + + if (a != that.a) return false; + if (b != that.b) return false; + if (!Arrays.equals(c, that.c)) return false; + if (!Arrays.equals(d, that.d)) return false; + return e.equals(that.e); + } + + @Override + public int hashCode() { + int result = (a ? 
1 : 0); + result = 31 * result + b; + result = 31 * result + Arrays.hashCode(c); + result = 31 * result + Arrays.hashCode(d); + result = 31 * result + e.hashCode(); + return result; + } + } + + public class NestedJavaBean implements Serializable { + private SimpleJavaBean a; + + public SimpleJavaBean getA() { + return a; + } + + public void setA(SimpleJavaBean a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NestedJavaBean that = (NestedJavaBean) o; + + return a.equals(that.a); + } + + @Override + public int hashCode() { + return a.hashCode(); + } + } + + @Test + public void testJavaBeanEncoder() { + OuterScopes.addOuterScope(this); + SimpleJavaBean obj1 = new SimpleJavaBean(); + obj1.setA(true); + obj1.setB(3); + obj1.setC(new byte[]{1}); + obj1.setD(new String[]{"hello"}); + obj1.setE(Arrays.asList("a", "b")); + SimpleJavaBean obj2 = new SimpleJavaBean(); + obj2.setA(false); + obj2.setB(30); + obj2.setC(new byte[]{2}); + obj2.setD(new String[]{"world"}); + obj2.setE(Arrays.asList("x", "y")); + + List data = Arrays.asList(obj1, obj2); + Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds.collectAsList()); + + NestedJavaBean obj3 = new NestedJavaBean(); + obj3.setA(obj1); + + List data2 = Arrays.asList(obj3); + Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); + Assert.assertEquals(data2, ds2.collectAsList()); + } } From e8527102e99d842e7910645f9ea4e6ac513acd08 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Nov 2015 16:11:20 +0800 Subject: [PATCH 2/6] fix test --- .../org/apache/spark/sql/JavaDataFrameSuite.java | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index
8e0b2dbca4a9..8027f2f72d28 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -136,17 +136,33 @@ public double getA() { return a; } + public void setA(double a) { + this.a = a; + } + public Integer[] getB() { return b; } + public void setB(Integer[] b) { + this.b = b; + } + public Map getC() { return c; } + public void setC(Map c) { + this.c = c; + } + public List getD() { return d; } + + public void setD(List d) { + this.d = d; + } } void validateDataFrameWithBeans(Bean bean, DataFrame df) { From 7940ec74f7fb1d06fbab703068ed1a5a3dbd204c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Nov 2015 19:51:28 +0800 Subject: [PATCH 3/6] add doc and todo --- .../scala/org/apache/spark/sql/Encoder.scala | 9 +++++++- .../sql/catalyst/JavaTypeInference.scala | 22 ++++++++++++++++--- .../catalyst/encoders/ExpressionEncoder.scala | 4 ++-- .../sql/catalyst/expressions/objects.scala | 5 ++++- .../sql/catalyst/util/GenericArrayData.scala | 3 +++ 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 9bb755fd34ad..f0c8dcee0557 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -97,7 +97,14 @@ object Encoders { */ def STRING: Encoder[java.lang.String] = ExpressionEncoder() - def bean[T](beanCls: Class[T]): Encoder[T] = ExpressionEncoder(beanCls) + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder(beanClass) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 01c185b6a952..7c32690a3ed3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -137,6 +137,13 @@ object JavaTypeInference { keyType -> valueType } + /** + * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping + * to a native type, an ObjectType is returned. + * + * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type + * system. As a result, ObjectType will be returned for things like boxed Integers. + */ private def inferExternalType(cls: Class[_]): DataType = cls match { case c if c == java.lang.Boolean.TYPE => BooleanType case c if c == java.lang.Byte.TYPE => ByteType @@ -149,6 +156,12 @@ object JavaTypeInference { case _ => ObjectType(cls) } + /** + * Returns an expression that can be used to construct an object of java bean `T` given an input + * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * of the same name as the bean properties. Nested classes will have their fields accessed + * using UnresolvedExtractValue. + */ def constructorFor(beanClass: Class[_]): Expression = { constructorFor(TypeToken.of(beanClass), None) } @@ -294,14 +307,17 @@ object JavaTypeInference { } } - def extractorsFor(inputObject: Expression, beanClass: Class[_]): CreateNamedStruct = { + /** + * Returns expressions for extracting all the fields from the given type.
+ */ + def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] } private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { - val externalType = inferExternalType(elementType.getRawType) val (dataType, nullable) = inferDataType(elementType) if (ScalaReflection.isNativeType(dataType)) { NewInstance( @@ -309,7 +325,7 @@ object JavaTypeInference { input :: Nil, dataType = ArrayType(dataType, nullable)) } else { - MapObjects(extractorFor(_, elementType), input, externalType) + MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6ef9159f7d82..6c4721016a20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -66,12 +66,12 @@ object ExpressionEncoder { ClassTag[T](cls)) } + // TODO: improve error message for java bean encoder. 
def apply[T](beanClass: Class[T]): ExpressionEncoder[T] = { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val toRowExpression = JavaTypeInference.extractorsFor(inputObject, beanClass) + val toRowExpression = JavaTypeInference.extractorsFor(beanClass) val fromRowExpression = JavaTypeInference.constructorFor(beanClass) new ExpressionEncoder[T]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 501781cf7bea..a9e01fd71112 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -600,10 +600,13 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B override def dataType: DataType = ObjectType(tag.runtimeClass) } +/** + * Initialize a Java Bean instance by setting its field values via setters. 
+ */ case class InitializeJavaBean(n: NewInstance, setters: Map[String, Expression]) extends Expression { - override def nullable: Boolean = false + override def nullable: Boolean = n.nullable override def children: Seq[Expression] = n +: setters.values.toSeq override def dataType: DataType = n.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index 96588bb5dc1b..2b8cdc1e23ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.util +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.{DataType, Decimal} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -24,6 +26,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seq: Seq[Any]) = this(seq.toArray) + def this(list: java.util.List[Any]) = this(list.asScala) // TODO: This is boxing. We should specialize. 
def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) From fb9a898836820077b81baf21abf9f738eb716685 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Nov 2015 20:02:37 +0800 Subject: [PATCH 4/6] fix test --- .../sql/catalyst/JavaTypeInference.scala | 45 ++++++++++--------- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ------- 2 files changed, 23 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 7c32690a3ed3..be349e671ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -101,18 +101,17 @@ object JavaTypeInference { val (valueDataType, nullable) = inferDataType(valueType) (MapType(keyDataType, valueDataType, nullable), true) - case other => - val properties = getJavaBeanProperties(other) - if (properties.length > 0) { - val fields = properties.map { property => - val returnType = typeToken.method(property.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(returnType) - new StructField(property.getName, dataType, nullable) - } - (new StructType(fields), true) - } else { - throw new UnsupportedOperationException(s"Cannot infer data type for ${other.getName}") + case _ => + // TODO: we should only collect properties that have getter and setter. However, some tests + // pass in scala case class as java bean class which doesn't have getter and setter. 
+ val beanInfo = Introspector.getBeanInfo(typeToken.getRawType) + val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val fields = properties.map { property => + val returnType = typeToken.method(property.getReadMethod).getReturnType + val (dataType, nullable) = inferDataType(returnType) + new StructField(property.getName, dataType, nullable) } + (new StructType(fields), true) } } @@ -387,17 +386,19 @@ object JavaTypeInference { case other => val properties = getJavaBeanProperties(other) - assert(properties.length > 0) - - CreateNamedStruct(properties.flatMap { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val fieldValue = Invoke( - inputObject, - p.getReadMethod.getName, - inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil - }) + if (properties.length > 0) { + CreateNamedStruct(properties.flatMap { p => + val fieldName = p.getName + val fieldType = typeToken.method(p.getReadMethod).getReturnType + val fieldValue = Invoke( + inputObject, + p.getReadMethod.getName, + inferExternalType(fieldType.getRawType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + } else { + throw new UnsupportedOperationException(s"no encoder found for ${other.getName}") + } } } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 8027f2f72d28..8e0b2dbca4a9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -136,33 +136,17 @@ public double getA() { return a; } - public void setA(double a) { - this.a = a; - } - public Integer[] getB() { return b; } - public void setB(Integer[] b) { - this.b = b; - } - public Map getC() { return c; } - public void setC(Map c) { - this.c = c; 
- } - public List getD() { return d; } - - public void setD(List d) { - this.d = d; - } } void validateDataFrameWithBeans(Bean bean, DataFrame df) { From 4abac6865cc3f6667c6053a384661e2674a26407 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 11:45:46 +0800 Subject: [PATCH 5/6] address comments --- .../scala/org/apache/spark/sql/Encoder.scala | 11 +++++- .../sql/catalyst/JavaTypeInference.scala | 3 ++ .../catalyst/encoders/ExpressionEncoder.scala | 2 +- .../sql/catalyst/expressions/objects.scala | 20 +++++------ .../spark/sql/catalyst/trees/TreeNode.scala | 4 +-- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +++++++++++++ .../apache/spark/sql/JavaDatasetSuite.java | 36 +++++++++++++++---- 7 files changed, 80 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index f0c8dcee0557..c40061ae0aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -102,9 +102,18 @@ object Encoders { * * T must be publicly accessible. * + * Supported types for Java bean fields: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time-related types: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested Java beans + * * @since 1.6.0 */ - def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder(beanClass) + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) /** * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. * This encoder maps T into a single byte array (binary) field. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index be349e671ba5..c8ee87e8819f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -382,6 +382,9 @@ object JavaTypeInference { toCatalystArray(inputObject, elementType(typeToken)) case _ if mapType.isAssignableFrom(typeToken) => + // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can + // not guarantee they have the same iteration order (which is different from scala map). + // A possible solution is creating a new `MapObjects` that can iterate a map directly. throw new UnsupportedOperationException("map type is not supported currently") case other => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 6c4721016a20..0bdc32580a35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -67,7 +67,7 @@ object ExpressionEncoder { } // TODO: improve error message for java bean encoder. 
- def apply[T](beanClass: Class[T]): ExpressionEncoder[T] = { + def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index a9e01fd71112..e6ab9a31be59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -603,34 +603,34 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B /** * Initialize a Java Bean instance by setting its field values via setters. */ -case class InitializeJavaBean(n: NewInstance, setters: Map[String, Expression]) +case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Expression]) extends Expression { - override def nullable: Boolean = n.nullable - override def children: Seq[Expression] = n +: setters.values.toSeq - override def dataType: DataType = n.dataType + override def nullable: Boolean = beanInstance.nullable + override def children: Seq[Expression] = beanInstance +: setters.values.toSeq + override def dataType: DataType = beanInstance.dataType override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val instance = n.gen(ctx) + val instanceGen = beanInstance.gen(ctx) val initialize = setters.map { case (setterMethod, fieldValue) => val fieldGen = fieldValue.gen(ctx) s""" ${fieldGen.code} - ${instance.value}.$setterMethod(${fieldGen.value}); + ${instanceGen.value}.$setterMethod(${fieldGen.value}); """ } - ev.isNull = instance.isNull - ev.value = instance.value + ev.isNull = instanceGen.isNull + ev.value = 
instanceGen.value s""" - ${instance.code} - if (!${instance.isNull}) { + ${instanceGen.code} + if (!${instanceGen.isNull}) { ${initialize.mkString("\n")} } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0650611c3916..f1cea07976a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -205,7 +205,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { } case nonChild: AnyRef => nonChild case null => null - }.view.force + }.view.force // `mapValues` is lazy and we need to force it to materialize case arg: TreeNode[_] if containsChild(arg) => val newChild = remainingNewChildren.remove(0) val oldChild = remainingOldChildren.remove(0) @@ -287,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { arg } case other => other - }.view.force + }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs case args: Traversable[_] => args.map { case arg: TreeNode[_] if containsChild(arg) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 8fff39906b34..965bdb1515e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -38,6 +38,13 @@ case class ComplexPlan(exprs: Seq[Seq[Expression]]) override def output: Seq[Attribute] = Nil } +case class ExpressionInMap(map: Map[String, Expression]) extends Expression with Unevaluable { + override def children: Seq[Expression] = map.values.toSeq + override def nullable: Boolean = true + override def 
dataType: NullType = NullType + override lazy val resolved = true +} + class TreeNodeSuite extends SparkFunSuite { test("top node changed") { val after = Literal(1) transform { case Literal(1, _) => Literal(2) } @@ -236,4 +243,22 @@ class TreeNodeSuite extends SparkFunSuite { val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2")))) assert(expected === actual) } + + test("expressions inside a map") { + val expression = ExpressionInMap(Map("1" -> Literal(1), "2" -> Literal(2))) + + { + val actual = expression.transform { + case Literal(i: Int, _) => Literal(i + 1) + } + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + + { + val actual = expression.withNewChildren(Seq(Literal(2), Literal(3))) + val expected = ExpressionInMap(Map("1" -> Literal(2), "2" -> Literal(3))) + assert(actual === expected) + } + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 438fa4decce0..83e1db90e306 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -31,15 +31,15 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.Encoder; -import org.apache.spark.sql.Encoders; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.GroupedDataset; +import org.apache.spark.sql.*; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; +import static org.apache.spark.sql.types.DataTypes.*; public class JavaDatasetSuite implements 
Serializable { private transient JavaSparkContext jsc; @@ -613,14 +613,14 @@ public void testJavaBeanEncoder() { SimpleJavaBean obj1 = new SimpleJavaBean(); obj1.setA(true); obj1.setB(3); - obj1.setC(new byte[]{1}); + obj1.setC(new byte[]{1, 2}); obj1.setD(new String[]{"hello"}); obj1.setE(Arrays.asList("a", "b")); SimpleJavaBean obj2 = new SimpleJavaBean(); obj2.setA(false); obj2.setB(30); - obj2.setC(new byte[]{2}); - obj1.setD(new String[]{"world"}); + obj2.setC(new byte[]{3, 4}); + obj2.setD(new String[]{"world"}); obj2.setE(Arrays.asList("x", "y")); List data = Arrays.asList(obj1, obj2); @@ -633,5 +633,27 @@ public void testJavaBeanEncoder() { List data2 = Arrays.asList(obj3); Dataset ds2 = context.createDataset(data2, Encoders.bean(NestedJavaBean.class)); Assert.assertEquals(data2, ds2.collectAsList()); + + Row row1 = new GenericRow(new Object[]{ + true, + 3, + new byte[]{1, 2}, + new String[]{"hello"}, + Arrays.asList("a", "b")}); + Row row2 = new GenericRow(new Object[]{ + false, + 30, + new byte[]{3, 4}, + new String[]{"world"}, + Arrays.asList("x", "y")}); + StructType schema = new StructType() + .add("a", BooleanType, false) + .add("b", IntegerType, false) + .add("c", BinaryType) + .add("d", createArrayType(StringType)) + .add("e", createArrayType(StringType)); + Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) + .as(Encoders.bean(SimpleJavaBean.class)); + Assert.assertEquals(data, ds3.collectAsList()); } } From 0d95daf74cad5ff478d7110dcad1f674971e6f1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 1 Dec 2015 11:54:05 +0800 Subject: [PATCH 6/6] improve test --- .../apache/spark/sql/JavaDatasetSuite.java | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 83e1db90e306..ae47f4fe0e23 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -514,6 +514,7 @@ public class SimpleJavaBean implements Serializable { private byte[] c; private String[] d; private List e; + private List f; public boolean isA() { return a; @@ -555,6 +556,14 @@ public void setE(List e) { this.e = e; } + public List getF() { + return f; + } + + public void setF(List f) { + this.f = f; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -566,7 +575,8 @@ public boolean equals(Object o) { if (b != that.b) return false; if (!Arrays.equals(c, that.c)) return false; if (!Arrays.equals(d, that.d)) return false; - return e.equals(that.e); + if (!e.equals(that.e)) return false; + return f.equals(that.f); } @Override @@ -576,6 +586,7 @@ public int hashCode() { result = 31 * result + Arrays.hashCode(c); result = 31 * result + Arrays.hashCode(d); result = 31 * result + e.hashCode(); + result = 31 * result + f.hashCode(); return result; } } @@ -614,14 +625,16 @@ public void testJavaBeanEncoder() { obj1.setA(true); obj1.setB(3); obj1.setC(new byte[]{1, 2}); - obj1.setD(new String[]{"hello"}); + obj1.setD(new String[]{"hello", null}); obj1.setE(Arrays.asList("a", "b")); + obj1.setF(Arrays.asList(100L, null, 200L)); SimpleJavaBean obj2 = new SimpleJavaBean(); obj2.setA(false); obj2.setB(30); obj2.setC(new byte[]{3, 4}); - obj2.setD(new String[]{"world"}); + obj2.setD(new String[]{null, "world"}); obj2.setE(Arrays.asList("x", "y")); + obj2.setF(Arrays.asList(300L, null, 400L)); List data = Arrays.asList(obj1, obj2); Dataset ds = context.createDataset(data, Encoders.bean(SimpleJavaBean.class)); @@ -638,20 +651,23 @@ public void testJavaBeanEncoder() { true, 3, new byte[]{1, 2}, - new String[]{"hello"}, - Arrays.asList("a", "b")}); + new String[]{"hello", null}, + Arrays.asList("a", "b"), + Arrays.asList(100L, null, 200L)}); Row row2 = new GenericRow(new Object[]{ 
false, 30, new byte[]{3, 4}, - new String[]{"world"}, - Arrays.asList("x", "y")}); + new String[]{null, "world"}, + Arrays.asList("x", "y"), + Arrays.asList(300L, null, 400L)}); StructType schema = new StructType() .add("a", BooleanType, false) .add("b", IntegerType, false) .add("c", BinaryType) .add("d", createArrayType(StringType)) - .add("e", createArrayType(StringType)); + .add("e", createArrayType(StringType)) + .add("f", createArrayType(LongType)); Dataset ds3 = context.createDataFrame(Arrays.asList(row1, row2), schema) .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList());