@@ -307,7 +307,10 @@ object ScalaReflection extends ScalaReflection {
Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false)
}

case t if t <:< localTypeOf[Seq[_]] =>
// We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array
// to a `Set`, if there are duplicated elements, the elements will be de-duplicated.
case t if t <:< localTypeOf[Seq[_]] ||
t <:< localTypeOf[scala.collection.Set[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, elementNullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
@@ -325,8 +328,10 @@ object ScalaReflection extends ScalaReflection {
}

val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
val cls = companion.declaration(newTermName("newBuilder")) match {
case NoSymbol => classOf[Seq[_]]
val cls = companion.member(newTermName("newBuilder")) match {
Contributor:
Does it work in Scala 2.10?

Member Author:
It works.

To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /___/ .__/\_,_/_/ /_/\_\   version 2.3.0-SNAPSHOT
      /_/

Using Scala version 2.10.6 (Java HotSpot(TM) 64-Bit Server VM, Java 1.8.0_102)
Type in expressions to have them evaluated.
Type :help for more information.
...

scala> val t = ScalaReflection.localTypeOf[scala.collection.mutable.HashSet[Int]]
t: org.apache.spark.sql.catalyst.ScalaReflection.universe.Type = scala.collection.mutable.HashSet[Int]

scala> val companion = t.normalize.typeSymbol.companionSymbol.typeSignature
companion: org.apache.spark.sql.catalyst.ScalaReflection.universe.Type = scala.collection.mutable.HashSet.type

scala> val cls = companion.member(newTermName("newBuilder"))
cls: org.apache.spark.sql.catalyst.ScalaReflection.universe.Symbol = method newBuilder

case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]]
case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] =>
classOf[scala.collection.Set[_]]
case _ => mirror.runtimeClass(t.typeSymbol.asClass)
}
UnresolvedMapObjects(mapFunction, getPath, Some(cls))
@@ -498,6 +503,19 @@ object ScalaReflection extends ScalaReflection {
serializerFor(_, valueType, valuePath, seenTypeSet),
valueNullable = !valueType.typeSymbol.asClass.isPrimitive)

case t if t <:< localTypeOf[scala.collection.Set[_]] =>
val TypeRef(_, _, Seq(elementType)) = t

// There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array.
// Note that the property of `Set` is only kept when manipulating the data as domain object.
val newInput =
Invoke(
inputObject,
"toSeq",
ObjectType(classOf[Seq[_]]))
Member Author:
For primitive element types, calling toArray would let toCatalystArray construct an UnsafeArrayData directly. However, toArray requires a ClassTag, and generating one with a StaticInvoke might be hacky, so for now I simply use toSeq.
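To illustrate the ClassTag point outside the codegen path, here is a standalone sketch (hypothetical helper names, not code from this patch):

```scala
import scala.reflect.ClassTag

// toSeq needs no ClassTag: the resulting Seq keeps elements boxed.
def asSeq[A](s: scala.collection.Set[A]): Seq[A] = s.toSeq

// toArray must allocate a JVM array of the element's runtime class,
// so it needs an implicit ClassTag[A].
def asArray[A: ClassTag](s: scala.collection.Set[A]): Array[A] = s.toArray

asSeq(Set(1, 2, 3))   // Seq(1, 2, 3)
asArray(Set(1, 2, 3)) // Array(1, 2, 3)
```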


toCatalystArray(newInput, elementType)

case t if t <:< localTypeOf[String] =>
StaticInvoke(
classOf[UTF8String],
@@ -702,6 +720,10 @@ object ScalaReflection extends ScalaReflection {
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< localTypeOf[Set[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
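For reference, a rough sketch of what the new schemaFor branch yields for Set types (the values in the comments are expected results, not verified output):

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

ScalaReflection.schemaFor[Set[Int]]
// ~ Schema(ArrayType(IntegerType, containsNull = false), nullable = true)

ScalaReflection.schemaFor[Set[Option[Int]]]
// ~ Schema(ArrayType(IntegerType, containsNull = true), nullable = true)
```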
@@ -598,8 +598,9 @@ case class MapObjects private(

val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) ||
classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala sequence or set
val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
val builder = ctx.freshName("collectionBuilder")
(
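For intuition, the generated add-to-builder loop corresponds roughly to this hand-written Scala (a sketch, not the actual emitted Java; the concrete companion comes from customCollectionCls):

```scala
// Roughly what the generated code does when cls is scala.collection.immutable.Set:
val builder = scala.collection.immutable.Set.newBuilder[Any]   // getBuilder
Seq(1, 2, 2).foreach { elem =>
  builder += elem                                               // addElement
}
val result = builder.result()                                   // getResult => Set(1, 2)
```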
10 changes: 10 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 2.3.0 */
implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()

/**
* Notice that we serialize `Set` to Catalyst array. The set property is only kept when
* manipulating the domain objects. The serialization format doesn't keep the set property.
* When we have a Catalyst array which contains duplicated elements and convert it to
* `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated.
*
* @since 2.3.0
*/
implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder()

// Arrays

/** @since 1.6.1 */
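A usage sketch of the new encoder (assumes a spark-shell style session where `spark.implicits._` is in scope; the results shown in comments are illustrative):

```scala
import spark.implicits._

// Round-trip a Set through a Dataset.
Seq(Set(1, 2, 3)).toDS().collect()                   // Array(Set(1, 2, 3))

// Reading a Catalyst array with duplicates back as a Set de-duplicates.
Seq(Seq(1, 1, 2)).toDF("c").as[Set[Int]].collect()   // Array(Set(1, 2))
```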
@@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
df.select(collect_set($"a"), collect_set($"b")),
Seq(Row(Seq(1, 2, 3), Seq(2, 4)))
)

checkDataset(
df.select(collect_set($"a").as("aSet")).as[Set[Int]],
Set(1, 2, 3))
checkDataset(
df.select(collect_set($"b").as("bSet")).as[Set[Int]],
Set(2, 4))
checkDataset(
df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])],
Seq(Set(1, 2, 3) -> Set(2, 4)): _*)
}

test("collect functions structs") {
@@ -17,6 +17,7 @@

package org.apache.spark.sql

import scala.collection.immutable.{HashSet => HSet}
import scala.collection.immutable.Queue
import scala.collection.mutable.{LinkedHashMap => LHMap}
import scala.collection.mutable.ArrayBuffer
@@ -339,6 +340,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
}

test("arbitrary sets") {
Member:
Better to test null cases?

Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Int]]

Member Author:
Added a test for it.

checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4))
checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong))
checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble))
checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat))
checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte))
checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort))
checkDataset(Seq(Set(true, false)).toDS(), Set(true, false))
checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2"))
checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2)))

checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2))
checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong))
checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble))
checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat))
checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte))
checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort))
checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false))
checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2"))
checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2)))

checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]],
Seq(Set[Integer](1, null), Set[Integer](2)): _*)
}

test("nested sequences") {
checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
@@ -349,6 +375,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
}

test("nested set") {
checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4)))
checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4)))
}

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))