Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,46 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag
import scala.util.{Try, Success}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spark's code style discourages the use of Try and Success. Can you refactor your code a little, e.g. by moving cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) out of the Invoke code block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. I tried looking up the code style you mentioned, but only found the Databricks Scala Code Style Guide, which is not mentioned in the Spark docs as far as I know.

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
Try(cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])) match {
case Success(_) =>
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
case _ => Nil
}
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
  // The deserializer built for List[Int] should report List itself as its object type.
  val expectedType = ObjectType(classOf[List[_]])
  assert(deserializerFor[List[Int]].dataType == expectedType)
}

test("serialize and deserialize arbitrary sequence types") {
  import scala.collection.immutable.Queue
  import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke

  // Both collection types are expected to serialize to a non-nullable array of ints.
  val intArrayType = ArrayType(IntegerType, containsNull = false)

  // Queue: serialized as an int array, deserialized back into a Queue.
  val queueInput = BoundReference(0, ObjectType(classOf[Queue[Int]]), nullable = false)
  assert(serializerFor[Queue[Int]](queueInput).dataType.head.dataType == intArrayType)
  assert(deserializerFor[Queue[Int]].dataType == ObjectType(classOf[Queue[_]]))

  // ArrayBuffer: same expectations as for Queue.
  val bufferInput = BoundReference(0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false)
  assert(serializerFor[ArrayBuffer[Int]](bufferInput).dataType.head.dataType == intArrayType)
  assert(deserializerFor[ArrayBuffer[Int]].dataType == ObjectType(classOf[ArrayBuffer[_]]))

  // Check whether conversion is skipped when using WrappedArray[_] supertype
  // (would otherwise needlessly add overhead): the deserializer for Seq[Int] must be the
  // plain WrappedArray.make call itself.
  val seqInvoke = deserializerFor[Seq[Int]].asInstanceOf[StaticInvoke]
  assert(seqInvoke.staticObject == scala.collection.mutable.WrappedArray.getClass)
  assert(seqInvoke.functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

Expand Down
63 changes: 54 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,31 +100,76 @@ abstract class SQLImplicits {
// Seqs

/** @since 1.6.1 */
implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
implicit def newIntSeqEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
implicit def newLongSeqEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
implicit def newDoubleSeqEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
implicit def newFloatSeqEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
implicit def newByteSeqEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
implicit def newShortSeqEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
implicit def newBooleanSeqEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
implicit def newStringSeqEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
implicit def newProductSeqEncoder[A <: Product : TypeTag, T <: Seq[A] : TypeTag]: Encoder[T] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my only concern now. Can you provide more details about it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is the same as all the other ones, just with Product subclasses. If you were concerned about the TypeTag on A, it was actually not needed as T's tag already contains all the information. I just tested it to be sure and removed it.

ExpressionEncoder()

// Seqs with product (List) disambiguation

/**
 * Encoder for any `T` that is both a `Seq[Int]` and a `Product`; delegates to
 * `newIntSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newIntSeqWithProductEncoder[T <: Seq[Int] with Product : TypeTag]: Encoder[T] =
newIntSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[Long]` and a `Product`; delegates to
 * `newLongSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newLongSeqWithProductEncoder[T <: Seq[Long] with Product : TypeTag]: Encoder[T] =
newLongSeqEncoder

/** @since 2.2.0 */
implicit def newDoubleListEncoder[T <: Seq[Double] with Product : TypeTag]: Encoder[T] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be newDoubleSeqWithProductEncoder?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should. Thanks

newDoubleSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[Float]` and a `Product`; delegates to
 * `newFloatSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newFloatSeqWithProductEncoder[T <: Seq[Float] with Product : TypeTag]: Encoder[T] =
newFloatSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[Byte]` and a `Product`; delegates to
 * `newByteSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newByteSeqWithProductEncoder[T <: Seq[Byte] with Product : TypeTag]: Encoder[T] =
newByteSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[Short]` and a `Product`; delegates to
 * `newShortSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newShortSeqWithProductEncoder[T <: Seq[Short] with Product : TypeTag]: Encoder[T] =
newShortSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[Boolean]` and a `Product`; delegates to
 * `newBooleanSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newBooleanSeqWithProductEncoder[T <: Seq[Boolean] with Product : TypeTag]
: Encoder[T] =
newBooleanSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[String]` and a `Product`; delegates to
 * `newStringSeqEncoder`.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @since 2.2.0
 */
implicit def newStringSeqWithProductEncoder[T <: Seq[String] with Product : TypeTag]: Encoder[T] =
newStringSeqEncoder

/**
 * Encoder for any `T` that is both a `Seq[A]` and a `Product`, where the element type `A` is
 * itself a `Product`; delegates to `newProductSeqEncoder` with both type arguments passed
 * explicitly.
 *
 * NOTE(review): presumably exists so that such types (the section header cites List) resolve
 * to a single encoder instead of an ambiguous implicit -- confirm.
 *
 * @tparam A element type, a `Product` with a `TypeTag` available
 * @tparam T concrete sequence type, a `Seq[A]` that is also a `Product`
 * @since 2.2.0
 */
implicit def newProductSeqWithProductEncoder
[A <: Product : TypeTag, T <: Seq[A] with Product : TypeTag]: Encoder[T] =
newProductSeqEncoder[A, T]

// Workaround for an implicit resolution problem with Seq.toDS (only supports Seq):
// fixes the sequence type to exactly Seq[A] and delegates to newProductSeqEncoder, so only
// the element type A has to be inferred.
implicit def newProductSeqOnlyEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] =
newProductSeqEncoder[A, Seq[A]]

// Arrays

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,34 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also test nested sequences, e.g. List(Queue(1)), and sequences inside a product, e.g. List(1) -> Queue(1).

Copy link
Contributor Author

@michalsenkyr michalsenkyr Jan 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added some sequence-product combination tests.
Nested sequences were never supported (tried on master and 2.0.2). That would probably be worthy of another ticket.

import scala.collection.immutable.Queue
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
// Implicit resolution problem - encoder needs to be provided explicitly
implicit val queueEncoder = newProductSeqEncoder[Tuple1[Int], Queue[Tuple1[Int]]]
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

import scala.collection.mutable.ArrayBuffer
checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
// Implicit resolution problem - encoder needs to be provided explicitly
implicit val arrayBufferEncoder = newProductSeqEncoder[Tuple1[Int], ArrayBuffer[Tuple1[Int]]]
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
Expand Down