Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 45 additions & 17 deletions core/src/main/scala/ste/encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,26 @@ import scala.annotation.StaticAnnotation

import scala.collection.generic.IsTraversableOnce

final class Meta(val metadata: Metadata) extends StaticAnnotation
final case class Meta(metadata: Metadata) extends StaticAnnotation
/**
 * Field annotation asking the struct encoder to expand the annotated field into
 * dot-prefixed columns instead of a single column.
 *
 * @param times for array-typed fields, the element fields are emitted once per index
 *              `0 until times` when `times > 1`
 * @param keys  for map-typed fields, the value fields are emitted once per key when
 *              `keys` is non-empty
 */
final case class Flatten(times: Int = 1, keys: Seq[String] = Seq()) extends StaticAnnotation

/**
 * Type class describing how a Scala type `A` is represented as a `DataType`.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a DataTypeEncoder defined in the library.
You need to define one yourself.
""")
sealed trait DataTypeEncoder[A] {
/** The `DataType` that represents `A`. */
def encode: DataType
/** The struct fields exposed by this encoder, if any; read when a field is flattened. */
def fields: Option[Seq[StructField]]
/** Whether a field of this type is nullable. */
def nullable: Boolean
}

/** Summoner and constructor for [[DataTypeEncoder]] instances. */
object DataTypeEncoder {
/** Returns the implicit encoder in scope for `A`. */
def apply[A](implicit enc: DataTypeEncoder[A]): DataTypeEncoder[A] = enc

// NOTE(review): the next two lines are the removed and the added signature of `pure`
// as shown by the diff; the second one (with the optional `f` fields) is the current form.
def pure[A](dt: DataType, isNullable: Boolean = false): DataTypeEncoder[A] =
def pure[A](dt: DataType, f: Option[Seq[StructField]] = None, isNullable: Boolean = false): DataTypeEncoder[A] =
// Builds an encoder that simply returns the supplied values.
new DataTypeEncoder[A] {
def encode: DataType = dt
def fields: Option[Seq[StructField]] = f
def nullable: Boolean = isNullable
}
}
Expand All @@ -56,6 +59,7 @@ object DataTypeEncoder {
""")
sealed trait StructTypeEncoder[A] extends DataTypeEncoder[A] {
def encode: StructType
def fields: Option[Seq[StructField]]
def nullable: Boolean
}

Expand All @@ -65,6 +69,7 @@ object StructTypeEncoder extends MediumPriorityImplicits {
def pure[A](st: StructType, isNullable: Boolean = false): StructTypeEncoder[A] =
new StructTypeEncoder[A] {
def encode: StructType = st
def fields: Option[Seq[StructField]] = Some(st.fields.toSeq)
def nullable: Boolean = isNullable
}
}
Expand All @@ -80,7 +85,7 @@ sealed trait AnnotatedStructTypeEncoder[A] {
}

object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
type Encode = Seq[Metadata] => StructType
type Encode = (Seq[Metadata], Seq[Option[Flatten]]) => StructType

def pure[A](enc: Encode): AnnotatedStructTypeEncoder[A] =
new AnnotatedStructTypeEncoder[A] {
Expand All @@ -89,29 +94,46 @@ object AnnotatedStructTypeEncoder extends MediumPriorityImplicits {
}

trait LowPriorityImplicits {
// Base case of the record derivation: an empty HList encodes to an empty StructType.
// NOTE(review): the first line is the removed single-argument form from the diff; the
// two lines after it are the added form, which ignores both metadata and flatten lists.
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] = AnnotatedStructTypeEncoder.pure(_ => StructType(Nil))
implicit val hnilEncoder: AnnotatedStructTypeEncoder[HNil] =
AnnotatedStructTypeEncoder.pure((_, _) => StructType(Nil))
/**
 * Inductive case of the record derivation: encodes the head field named by `K` and
 * prepends the resulting field(s) to the fields encoded for the tail.
 *
 * In the added lines, when the head carries a `Flatten` annotation and its encoder
 * exposes `fields`, those fields are expanded through `flattenFields`; otherwise a
 * single `StructField` is produced from the field name, data type, nullability and
 * the head metadata.
 *
 * NOTE(review): this span interleaves removed and added diff lines.
 */
implicit def hconsEncoder[K <: Symbol, H, T <: HList](
implicit
witness: Witness.Aux[K],
hEncoder: Lazy[DataTypeEncoder[H]],
tEncoder: AnnotatedStructTypeEncoder[T]
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { metadata =>
): AnnotatedStructTypeEncoder[FieldType[K, H] :: T] = AnnotatedStructTypeEncoder.pure { (metadata, flatten) =>
val fieldName = witness.value.name
val head = hEncoder.value.encode
val nullable = hEncoder.value.nullable
val tail = tEncoder.encode(metadata.tail)
StructType(StructField(fieldName, head, nullable, metadata.head) +: tail.fields)
val dt = hEncoder.value.encode
// Flattened fields are used only when both the annotation and the encoder's fields exist.
val fields = flatten.head.flatMap(f => hEncoder.value.fields.map(flattenFields(_, dt, fieldName, f))).getOrElse(
Seq(StructField(fieldName, dt, hEncoder.value.nullable, metadata.head)))
// metadata and flatten are consumed in lock-step: head for this field, tail for the rest.
val tail = tEncoder.encode(metadata.tail, flatten.tail)
StructType(fields ++ tail.fields)
}

implicit def recordEncoder[A, H <: HList, HA <: HList](
/**
 * Renames `fields` so that they sit under `prefix`, repeating them as requested by `flatten`.
 *
 *  - array type with `times > 1`: one copy of the fields per index, named `prefix.i.name`
 *  - map type with non-empty `keys`: one copy of the fields per key, named `prefix.k.name`
 *  - anything else: a single copy named `prefix.name`
 */
private def flattenFields(fields: Seq[StructField], dt: DataType, prefix: String, flatten: Flatten) =
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we specify the return type?

(dt, flatten) match {
case (_: ArrayType, Flatten(times, _)) if times > 1 =>
(0 until times).flatMap(i => fields.map(prefixStructField(_, s"$prefix.$i")))
case (_: MapType, Flatten(_, keys)) if keys.nonEmpty =>
keys.flatMap(k => fields.map(prefixStructField(_, s"$prefix.$k")))
case (_, Flatten(_, _)) => fields.map(prefixStructField(_, prefix))
}

/** Returns a copy of `f` renamed to `prefix`, a dot, then the field's original name. */
private def prefixStructField(f: StructField, prefix: String) = {
  val qualifiedName = s"$prefix.${f.name}"
  f.copy(name = qualifiedName)
}

/**
 * Derives a `StructTypeEncoder` for a case class `A` from its labelled generic
 * representation. In the added lines, the per-field `Meta` and `Flatten` annotations
 * are collected into lists and handed to the HList encoder.
 *
 * NOTE(review): this span interleaves removed and added diff lines.
 */
implicit def recordEncoder[A, H <: HList, HA <: HList, HF <: HList](
implicit
generic: LabelledGeneric.Aux[A, H],
annotations: Annotations.Aux[Meta, A, HA],
metaAnnotations: Annotations.Aux[Meta, A, HA],
flattenAnnotations: Annotations.Aux[Flatten, A, HF],
hEncoder: Lazy[AnnotatedStructTypeEncoder[H]],
toList: ToList[HA, Option[Meta]]
metaToList: ToList[HA, Option[Meta]],
flattenToList: ToList[HF, Option[Flatten]]
): StructTypeEncoder[A] = {
val metadata = annotations().toList[Option[Meta]].map(extractMetadata)
StructTypeEncoder.pure(hEncoder.value.encode(metadata))
// One entry per field, in declaration order, for each annotation kind.
val metadata = metaAnnotations().toList[Option[Meta]].map(extractMetadata)
val flatten = flattenAnnotations().toList[Option[Flatten]]
StructTypeEncoder.pure(hEncoder.value.encode(metadata, flatten))
}

private val extractMetadata: Option[Meta] => Metadata =
Expand Down Expand Up @@ -153,16 +175,22 @@ trait MediumPriorityImplicits extends LowPriorityImplicits {
enc: DataTypeEncoder[A0],
is: IsTraversableOnce[C[A0]] { type A = A0 }
): DataTypeEncoder[C[A0]] =
DataTypeEncoder.pure(ArrayType(enc.encode))
DataTypeEncoder.pure(ArrayType(enc.encode), enc.fields)
/**
 * Encoder for `collection.Map[K, V]`: a `MapType` assembled from the key and value
 * encoders, exposing the value encoder's fields.
 */
implicit def mapEncoder[K, V](
  implicit
  kEnc: DataTypeEncoder[K],
  vEnc: DataTypeEncoder[V]
): DataTypeEncoder[collection.Map[K, V]] = {
  val mapType = MapType(kEnc.encode, vEnc.encode)
  DataTypeEncoder.pure(mapType, vEnc.fields)
}
/**
 * Encoder for the default (immutable) `Map[K, V]`, built from the key and value encoders.
 *
 * NOTE(review): the last two lines are the removed and the added body from the diff;
 * the added one also forwards the value encoder's fields.
 */
implicit def immutableMapEncoder[K, V](
implicit
kEnc: DataTypeEncoder[K],
vEnc: DataTypeEncoder[V]
): DataTypeEncoder[Map[K, V]] =
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode))
DataTypeEncoder.pure(MapType(kEnc.encode, vEnc.encode), vEnc.fields)
/**
 * Encoder for `Option[V]`: same data type as `V`, but marked nullable.
 *
 * NOTE(review): the last two lines are the removed and the added body from the diff;
 * the added one passes `isNullable` by name because `pure` gained a middle parameter.
 */
implicit def optionEncoder[V](
implicit
enc: DataTypeEncoder[V]
): DataTypeEncoder[Option[V]] =
DataTypeEncoder.pure(enc.encode, true)
DataTypeEncoder.pure(enc.encode, isNullable = true)
}
197 changes: 197 additions & 0 deletions core/src/main/scala/ste/selector.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package ste
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we add the license header in this file and the associated spec?


import org.apache.spark.sql.{ Column, DataFrame, Dataset, Encoder }
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.annotation.tailrec
import scala.collection.generic.IsTraversableOnce
import scala.collection.breakOut
import shapeless._
import shapeless.ops.hlist._
import shapeless.syntax.std.tuple._
import shapeless.labelled.FieldType

/**
 * A dot-separated column path such as `a.0.b`.
 *
 * @param p the raw dotted path
 */
case class Prefix(p: String) {
  /** Appends `s` as a new last segment. */
  def addSuffix(s: Any): Prefix = copy(p = s"$p.$s")

  /** The path without its last segment (the empty path when there is a single segment). */
  def getParent: Prefix = {
    val segments = p.split("\\.")
    Prefix(segments.take(segments.length - 1).mkString("."))
  }

  /** The last segment of the path. */
  def getSuffix: String = p.split("\\.").last

  /** True when `other` starts with this path followed by a dot. */
  def isParentOf(other: Prefix): Boolean = {
    val childMarker = s"$p."
    other.toString.startsWith(childMarker)
  }

  /** True when `other` is a parent of this path. */
  def isChildrenOf(other: Prefix): Boolean = other.isParentOf(this)

  /** The path wrapped in backticks. */
  def quotedString: String = "`" + p + "`"

  override def toString: String = p
}

/**
 * Type class carrying a `select` function that transforms a `DataFrame` for a field of type `A`.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a DataTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait DataTypeSelector[A] {
import DataTypeSelector.Select

/** Function from a frame and optional parent prefixes to the transformed frame. */
val select: Select
}

/** Type aliases and constructors for [[DataTypeSelector]]. */
object DataTypeSelector {
  /** A list of column-path prefixes. */
  type Prefixes = Seq[Prefix]
  /** A selection step: takes a frame and optional parent prefixes, returns a frame. */
  type Select = (DataFrame, Option[Prefixes]) => DataFrame

  /** Wraps the given function in a selector instance. */
  def pure[A](s: Select): DataTypeSelector[A] = {
    val selectFn = s
    new DataTypeSelector[A] {
      val select: Select = selectFn
    }
  }

  /** A selector that returns the frame it is given, ignoring the prefixes. */
  def identityDF[A]: DataTypeSelector[A] =
    pure[A]((frame, _) => frame)
}

/**
 * A [[DataTypeSelector]] for record-like types; instances are derived by `dfSelector`.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a StructTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait StructTypeSelector[A] extends DataTypeSelector[A] {
import DataTypeSelector.Select

/** Same contract as the parent's `select`, redeclared here. */
val select: Select
}

/** Summoner and constructor for [[StructTypeSelector]]; also brings the derivation implicits into implicit scope. */
object StructTypeSelector extends SelectorImplicits {
  /** Returns the implicit selector in scope for `A`. */
  def apply[A](implicit s: StructTypeSelector[A]): StructTypeSelector[A] = s

  /** Wraps the given function in a struct selector instance. */
  def pure[A](s: DataTypeSelector.Select): StructTypeSelector[A] = {
    val selectFn = s
    new StructTypeSelector[A] {
      val select: DataTypeSelector.Select = selectFn
    }
  }
}

/**
 * Selector used while walking a record's HList representation; its `select` also
 * receives the per-field `Flatten` annotations.
 */
@annotation.implicitNotFound("""
Type ${A} does not have a AnnotatedStructTypeSelector defined in the library.
You need to define one yourself.
""")
sealed trait AnnotatedStructTypeSelector[A] {
import AnnotatedStructTypeSelector.Select

/** Function from a frame, optional parent prefixes and flatten annotations to the transformed frame. */
val select: Select
}

/** Type alias and constructor for [[AnnotatedStructTypeSelector]]. */
object AnnotatedStructTypeSelector extends SelectorImplicits {
  /** A selection step that additionally receives one optional `Flatten` per remaining field. */
  type Select = (DataFrame, Option[DataTypeSelector.Prefixes], Seq[Option[Flatten]]) => DataFrame

  /** Wraps the given function in a selector instance. */
  def pure[A](s: Select): AnnotatedStructTypeSelector[A] = {
    val selectFn = s
    new AnnotatedStructTypeSelector[A] {
      val select: Select = selectFn
    }
  }
}

trait SelectorImplicits {
/** Base case of the derivation: with no fields left, the frame is returned unchanged. */
implicit val hnilSelector: AnnotatedStructTypeSelector[HNil] =
  AnnotatedStructTypeSelector.pure { (frame, _, _) => frame }

/**
 * Inductive case of the derivation: processes the head field named by `K`, then
 * hands the resulting frame to the tail selector.
 *
 * The head field's own selector runs first on the child prefixes. If the field is
 * annotated with `Flatten`, the flat columns under each child prefix are grouped into
 * a struct, combined through `getNestedColumns`, and re-selected in the original
 * column order by `orderedSelect`.
 */
implicit def hconsSelector[K <: Symbol, H, T <: HList](
  implicit
  witness: Witness.Aux[K],
  hSelector: Lazy[DataTypeSelector[H]],
  tSelector: AnnotatedStructTypeSelector[T]
): AnnotatedStructTypeSelector[FieldType[K, H] :: T] = AnnotatedStructTypeSelector.pure { (df, parentPrefixes, flatten) =>
  val name = witness.value.name
  // Extend every parent path with this field, or start a fresh path at the top level.
  val ownPrefixes = parentPrefixes match {
    case Some(parents) => parents.map(_.addSuffix(name))
    case None => Seq(Prefix(name))
  }
  val headAnnotation = flatten.head
  val childPrefixes = getChildPrefixes(ownPrefixes, headAnnotation)
  val dfHead = hSelector.value.select(df, Some(childPrefixes))
  val dfNested = headAnnotation match {
    case None => dfHead
    case Some(fl) =>
      val fields = dfHead.schema.fields.map(f => Prefix(f.name)).toSeq
      // Columns that do not live under any child prefix are carried over as they are.
      val untouched = fields
        .filterNot(f => childPrefixes.exists(_.isParentOf(f)))
        .map(f => dfHead(f.quotedString))
      // One struct per child prefix, holding its columns under their last path segment.
      val structs = childPrefixes.map { p =>
        val members = fields.filter(_.isChildrenOf(p)).map(f => dfHead(f.quotedString).as(f.getSuffix))
        struct(members: _*).as(p.toString)
      }
      val dfStruct = dfHead.select((structs ++ untouched): _*)
      val nestedCols = getNestedColumns(childPrefixes, dfStruct, fl)
      orderedSelect(dfStruct, nestedCols, fields)
  }
  tSelector.select(dfNested, parentPrefixes, flatten.tail)
}

/**
 * Expands `prefixes` according to an optional `Flatten` annotation.
 *
 *  - `times > 1`: every prefix is suffixed with each index in `0 until times`
 *  - non-empty `keys`: every prefix is suffixed with each key
 *  - otherwise, or when there is no annotation: the prefixes are returned as given
 */
private def getChildPrefixes(prefixes: Seq[Prefix], flatten: Option[Flatten]) =
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we're also missing the return type here

flatten.map(_ match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can just do

flatten.map {
  case ...
}

case Flatten(times, _) if times > 1 => (0 until times).flatMap(i => prefixes.map(_.addSuffix(i)))
case Flatten(_, keys) if keys.nonEmpty => keys.flatMap(k => prefixes.map(_.addSuffix(k)))
case Flatten(_, _) => prefixes
}).getOrElse(prefixes)

/**
 * Groups `prefixes` by their parent path and builds one column per group:
 * an array column when `times > 1`, a map column keyed by `keys` when they are
 * non-empty, and otherwise the group's first column keyed by its own prefix.
 */
private def getNestedColumns(prefixes: Seq[Prefix], df: DataFrame, flatten: Flatten): Map[Prefix, Column] =
  prefixes.groupBy(_.getParent).map { case (parent, members) =>
    val name = parent.toString
    val memberCols = members.map(p => df(p.quotedString))
    val entry: (Prefix, Column) =
      if (flatten.times > 1) (parent, array(memberCols: _*).as(name))
      else if (flatten.keys.nonEmpty) (parent, map(interleave(flatten.keys.map(lit), memberCols): _*).as(name))
      else (members.head, memberCols.head)
    entry
  }(breakOut)

/**
 * Selects columns from `df` following the order of `fields`. When a field is a child
 * of a prefix present in `nestedCols`, the nested column is emitted once in its place
 * and the following fields that are children of that prefix are skipped; any other
 * field is selected from `df` by its quoted name.
 */
private def orderedSelect(df: DataFrame, nestedCols: Map[Prefix, Column], fields: Seq[Prefix]) = {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing return type here too

@tailrec
def loop(nestedCols: Map[Prefix, Column], fields: Seq[Prefix], cols: Seq[Column]): Seq[Column] = fields match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we make fields and cols lists instead of seqs?

// Columns are accumulated by prepending, hence the final reverse.
case Nil => cols.reverse
case hd +: tail => nestedCols.find { case (p, _) => p.isParentOf(hd) } match {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since you have lists you can do hd :: tail, same thing below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea, unfortunately shapeless overrides the :: definition

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, too bad :(

// The matched prefix is removed from the map so its column is emitted only once.
case Some((p, c)) => loop(nestedCols - p, fields.dropWhile(_.isChildrenOf(p)), c +: cols)
case None => loop(nestedCols, tail, df(hd.quotedString) +: cols)
}
}
val cols = loop(nestedCols, fields, Seq[Column]())
df.select(cols :_*)
}

/**
 * Interleaves two sequences element by element: `a0, b0, a1, b1, ...`.
 * Pairing stops at the end of the shorter sequence.
 *
 * Written with a plain pattern match on the zipped pairs, so it needs nothing
 * beyond the standard library to turn each pair into two elements.
 */
private def interleave[T](a: Seq[T], b: Seq[T]): Seq[T] =
  a.zip(b).flatMap { case (left, right) => Seq(left, right) }

/**
 * Derives a `StructTypeSelector` for a case class `A`: on each call it collects the
 * per-field `Flatten` annotations and delegates to the selector of the record's
 * HList representation.
 */
implicit def dfSelector[A, H <: HList, HF <: HList](
  implicit
  generic: LabelledGeneric.Aux[A, H],
  flattenAnnotations: Annotations.Aux[Flatten, A, HF],
  hSelector: Lazy[AnnotatedStructTypeSelector[H]],
  flattenToList: ToList[HF, Option[Flatten]]
): StructTypeSelector[A] = {
  val selectRecord: DataTypeSelector.Select = { (df, prefixes) =>
    // One optional annotation per field, in declaration order.
    val perFieldFlatten = flattenAnnotations().toList[Option[Flatten]]
    hSelector.value.select(df, prefixes, perFieldFlatten)
  }
  StructTypeSelector.pure(selectRecord)
}

// Leaf types need no reshaping: each of these selectors returns the frame unchanged.
implicit val binarySelector: DataTypeSelector[Array[Byte]] =
  DataTypeSelector.identityDF[Array[Byte]]
implicit val booleanSelector: DataTypeSelector[Boolean] =
  DataTypeSelector.identityDF[Boolean]
implicit val byteSelector: DataTypeSelector[Byte] =
  DataTypeSelector.identityDF[Byte]
implicit val dateSelector: DataTypeSelector[java.sql.Date] =
  DataTypeSelector.identityDF[java.sql.Date]
implicit val decimalSelector: DataTypeSelector[BigDecimal] =
  DataTypeSelector.identityDF[BigDecimal]
implicit val doubleSelector: DataTypeSelector[Double] =
  DataTypeSelector.identityDF[Double]
implicit val floatSelector: DataTypeSelector[Float] =
  DataTypeSelector.identityDF[Float]
implicit val intSelector: DataTypeSelector[Int] =
  DataTypeSelector.identityDF[Int]
implicit val longSelector: DataTypeSelector[Long] =
  DataTypeSelector.identityDF[Long]
implicit val nullSelector: DataTypeSelector[Unit] =
  DataTypeSelector.identityDF[Unit]
implicit val shortSelector: DataTypeSelector[Short] =
  DataTypeSelector.identityDF[Short]
implicit val stringSelector: DataTypeSelector[String] =
  DataTypeSelector.identityDF[String]
implicit val timestampSelector: DataTypeSelector[java.sql.Timestamp] =
  DataTypeSelector.identityDF[java.sql.Timestamp]
// Options are likewise passed through untouched, whatever their element type.
implicit def optionSelector[T]: DataTypeSelector[Option[T]] =
  DataTypeSelector.identityDF[Option[T]]

/** Selector for collections: forwards the frame and prefixes to the element type's selector. */
implicit def traversableOnceSelector[A0, C[_]](
  implicit
  s: DataTypeSelector[A0],
  is: IsTraversableOnce[C[A0]] { type A = A0 }
): DataTypeSelector[C[A0]] = {
  val delegate: DataTypeSelector.Select = (frame, prefixes) => s.select(frame, prefixes)
  DataTypeSelector.pure(delegate)
}

/** Selector for `collection.Map`: forwards the frame and prefixes to the value type's selector. */
implicit def mapSelector[K, V](
  implicit s: DataTypeSelector[V]
): DataTypeSelector[collection.Map[K, V]] = {
  val delegate: DataTypeSelector.Select = (frame, prefixes) => s.select(frame, prefixes)
  DataTypeSelector.pure(delegate)
}
}

/** Extension syntax for `DataFrame`. */
object DFUtils {
/** Adds `asNested` to any `DataFrame`. */
implicit class EnhancedDF(df: DataFrame) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that these are names of implicits, but I would rename at least the class. EnhancedDF should be something like FlattenedDataFrame. Also, I would split the method, and create a public wrapper around the select only, like selectNested: DataFrame. Then asNested can be implemented with selectNested and as.

// Runs the StructTypeSelector for A on the frame with no parent prefixes, then
// converts the result to a Dataset[A] using the Encoder in scope.
def asNested[A : Encoder : StructTypeSelector]: Dataset[A] = StructTypeSelector[A].select(df, None).as[A]
}
}
15 changes: 12 additions & 3 deletions core/src/test/scala/ste/StructTypeEncoderSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,18 @@ class StructTypeEncoderSpec extends FlatSpec with Matchers {
.build

case class Foo(a: String, @Meta(metadata) b: Int)
StructTypeEncoder[Foo].encode shouldBe StructType(
StructField("a", StringType, false) ::
StructField("b", IntegerType, false, metadata) :: Nil
case class Bar(@Flatten(2) a: Seq[Foo], @Flatten(1, Seq("x", "y")) b: collection.Map[String, Foo], @Flatten c: Foo)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can't we have Map instead of collection.Map?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not, this fix hasn't been backported to spark 2.1 apache/spark#16161

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, would you be against an upgrade to 2.2.1 in release 0.2.0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was wrong, they backported it to 2.1.x, I'm upgrading the patch version, I think we should do the minor update in a separate PR

StructTypeEncoder[Bar].encode shouldBe StructType(
StructField("a.0.a", StringType, false) ::
StructField("a.0.b", IntegerType, false, metadata) ::
StructField("a.1.a", StringType, false) ::
StructField("a.1.b", IntegerType, false, metadata) ::
StructField("b.x.a", StringType, false) ::
StructField("b.x.b", IntegerType, false, metadata) ::
StructField("b.y.a", StringType, false) ::
StructField("b.y.b", IntegerType, false, metadata) ::
StructField("c.a", StringType, false) ::
StructField("c.b", IntegerType, false, metadata) :: Nil
)
}
}
Loading