Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SETL-34] Handle multiple CompoundKeys on the same field #36

Merged
merged 7 commits into from
Jan 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/main/scala/com/jcdecaux/setl/annotation/CompoundKey.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,21 @@ import scala.annotation.StaticAnnotation
*/
@InterfaceStability.Stable
final case class CompoundKey(id: String, position: String) extends StaticAnnotation

object CompoundKey {

  // Separator used to join the id and position of a CompoundKey annotation
  // into a single metadata string. Chosen because it is unlikely to occur in
  // a user-defined key id or position.
  private[this] val separator: String = "!@"

  import scala.reflect.runtime.{universe => ru}

  /**
   * Serialize a reflected CompoundKey annotation into a single string of the
   * form `id!@position`, so that several CompoundKey annotations of the same
   * field can be stored together in one metadata string array.
   *
   * @param compoundKey the reflected CompoundKey annotation
   * @return the serialized representation `id!@position`
   */
  def serialize(compoundKey: ru.AnnotationApi): String = {
    // The annotation's constructor arguments appear as literal constants in
    // the children of its tree (the head child is the constructor call itself).
    val attributes = compoundKey.tree.children.tail.collect {
      case ru.Literal(ru.Constant(attribute)) => attribute.toString
    }

    attributes.mkString(separator)
  }

  /**
   * Deserialize a string produced by [[serialize]] back into a CompoundKey.
   *
   * @param str a string of the form `id!@position`
   * @return the reconstructed CompoundKey
   * @throws IllegalArgumentException if the input does not split into exactly
   *                                  an id and a position
   */
  def deserialize(str: String): CompoundKey = {
    val data = str.split(separator)
    // Fail fast with a clear message instead of an opaque
    // ArrayIndexOutOfBoundsException on malformed input.
    require(data.length == 2, s"Cannot deserialize '$str' into a CompoundKey: expected format 'id${separator}position'")
    CompoundKey(data(0), data(1))
  }
}
56 changes: 35 additions & 21 deletions src/main/scala/com/jcdecaux/setl/internal/SchemaConverter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ import scala.reflect.runtime.{universe => ru}
*/
object SchemaConverter extends Logging {

private[setl] val COMPOUND_KEY: String = StructAnalyser.COMPOUND_KEY
private[setl] val COLUMN_NAME: String = StructAnalyser.COLUMN_NAME
private[setl] val COMPRESS: String = StructAnalyser.COMPRESS

private[this] val compoundKeySuffix: String = "_key"
private[this] val compoundKeyPrefix: String = "_"
private[this] val compoundKeySeparator: String = "-"
Expand Down Expand Up @@ -72,8 +76,8 @@ object SchemaConverter extends Logging {
.filter {
field =>
val dfContainsFieldName = dfColumns.contains(field.name)
val dfContainsFieldAlias = if (field.metadata.contains(ColumnName.toString())) {
dfColumns.contains(field.metadata.getStringArray(ColumnName.toString()).head)
val dfContainsFieldAlias = if (field.metadata.contains(COLUMN_NAME)) {
dfColumns.contains(field.metadata.getStringArray(COLUMN_NAME).head)
} else {
false
}
Expand Down Expand Up @@ -147,8 +151,8 @@ object SchemaConverter extends Logging {
*/
def replaceDFColNameByFieldName(structType: StructType)(dataFrame: DataFrame): DataFrame = {
val changes = structType
.filter(_.metadata.contains(ColumnName.toString()))
.map(x => x.metadata.getStringArray(ColumnName.toString())(0) -> x.name)
.filter(_.metadata.contains(COLUMN_NAME))
.map(x => x.metadata.getStringArray(COLUMN_NAME)(0) -> x.name)
.toMap

dataFrame.transform(renameColumnsOfDataFrame(changes))
Expand Down Expand Up @@ -183,8 +187,8 @@ object SchemaConverter extends Logging {
*/
def replaceFieldNameByColumnName(structType: StructType)(dataFrame: DataFrame): DataFrame = {
val changes = structType
.filter(_.metadata.contains(ColumnName.toString()))
.map(x => x.name -> x.metadata.getStringArray(ColumnName.toString())(0))
.filter(_.metadata.contains(COLUMN_NAME))
.map(x => x.name -> x.metadata.getStringArray(COLUMN_NAME)(0))
.toMap

dataFrame.transform(renameColumnsOfDataFrame(changes))
Expand All @@ -201,10 +205,9 @@ object SchemaConverter extends Logging {
* Drop all compound key columns
*/
def dropCompoundKeyColumns(structType: StructType)(dataFrame: DataFrame): DataFrame = {

val columnsToDrop = structType
.filter(_.metadata.contains(CompoundKey.toString()))
.map(_.metadata.getStringArray(CompoundKey.toString())(0))
.filter(_.metadata.contains(COMPOUND_KEY))
.map(_.metadata.getStringArray(COMPOUND_KEY)(0))
.toSet

if (columnsToDrop.nonEmpty && dataFrame.columns.intersect(columnsToDrop.toSeq.map(compoundKeyName)).isEmpty) {
Expand Down Expand Up @@ -244,15 +247,26 @@ object SchemaConverter extends Logging {
*/
private[this] def addCompoundKeyColumns(structType: StructType)(dataFrame: DataFrame): DataFrame = {
val keyColumns = structType
.filter(_.metadata.contains(CompoundKey.toString()))
.groupBy(_.metadata.getStringArray(CompoundKey.toString())(0))
.map {
row =>
val sortedCols = row._2
.sortBy(_.metadata.getStringArray(CompoundKey.toString())(1).toInt)
.map(n => functions.col(n.name))
(row._1, sortedCols)
.filter(_.metadata.contains(COMPOUND_KEY))
.flatMap {
structField =>
structField.metadata
.getStringArray(COMPOUND_KEY)
.map {
data =>
val compoundKey = CompoundKey.deserialize(data)
(structField.name, compoundKey)
}
}
.groupBy(_._2.id)
.map {
case (key, fields) =>
val sortedColumns = fields.sortBy(_._2.position).map {
case (colname, _) => functions.col(colname)
}

(key, sortedColumns)
}.toList.sortBy(_._1)

// For each element in keyColumns, add a new column to the input dataFrame
keyColumns
Expand All @@ -270,12 +284,12 @@ object SchemaConverter extends Logging {
*/
def compressColumn(structType: StructType)(dataFrame: DataFrame): DataFrame = {

val columnToCompress = structType.filter(_.metadata.contains(classOf[Compress].getCanonicalName))
val columnToCompress = structType.filter(_.metadata.contains(COMPRESS))

columnToCompress
.foldLeft(dataFrame) {
(df, sf) =>
val compressorName = sf.metadata.getStringArray(classOf[Compress].getCanonicalName).head
val compressorName = sf.metadata.getStringArray(COMPRESS).head
val compressor = Class.forName(compressorName).newInstance().asInstanceOf[Compressor]
val compress: String => Array[Byte] = input => compressor.compress(input)
val compressUDF = functions.udf(compress)
Expand All @@ -292,12 +306,12 @@ object SchemaConverter extends Logging {
*/
def decompressColumn(structType: StructType)(dataFrame: DataFrame): DataFrame = {

val columnToDecompress = structType.filter(_.metadata.contains(classOf[Compress].getCanonicalName))
val columnToDecompress = structType.filter(_.metadata.contains(COMPRESS))

columnToDecompress
.foldLeft(dataFrame) {
(df, sf) => {
val compressorName = sf.metadata.getStringArray(classOf[Compress].getCanonicalName).head
val compressorName = sf.metadata.getStringArray(COMPRESS).head
val compressor = Class.forName(compressorName).newInstance().asInstanceOf[Compressor]
val decompress: Array[Byte] => String = input => compressor.decompress(input)
val decompressUDF = functions.udf(decompress)
Expand Down
34 changes: 26 additions & 8 deletions src/main/scala/com/jcdecaux/setl/internal/StructAnalyser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ import scala.reflect.runtime.{universe => ru}

object StructAnalyser extends Logging {

private[setl] val COMPOUND_KEY: String = classOf[CompoundKey].getCanonicalName
private[setl] val COLUMN_NAME: String = classOf[ColumnName].getCanonicalName
private[setl] val COMPRESS: String = classOf[Compress].getCanonicalName

/**
* Analyse the metadata of the generic type T. Fetch information for its annotated fields.
*
Expand Down Expand Up @@ -40,25 +44,31 @@ object StructAnalyser extends Logging {
val value = columnName.tree.children.tail.collectFirst {
case ru.Literal(ru.Constant(name)) => name.toString
}
(ColumnName.toString(), Array(value.get)) // (ColumnName, ["alias"])
(COLUMN_NAME, Array(value.get)) // (ColumnName, ["alias"])

// Case where the field has annotation `CompoundKey`
case compoundKey: ru.AnnotationApi if compoundKey.tree.tpe =:= ru.typeOf[CompoundKey] =>
val attributes = Some(compoundKey.tree.children.tail.collect {
case ru.Literal(ru.Constant(attribute)) => attribute.toString
})
val attribute = CompoundKey.serialize(compoundKey)
// All compound key column should not be nullable
nullable = false
(CompoundKey.toString(), attributes.get.toArray) // (ColumnName, ["id", "position"])
(COMPOUND_KEY, Array(attribute)) // (ColumnName, ["id", "position"])

case compress: ru.AnnotationApi if compress.tree.tpe =:= ru.typeOf[Compress] =>
val compressor = columnToBeCompressed.find(_._1 == index).get._2.getCanonicalName
(classOf[Compress].getCanonicalName, Array(compressor)) // (com.jcdecaux.setl.xxxx, ["compressor_canonical_name"])
(COMPRESS, Array(compressor)) // (com.jcdecaux.setl.xxxx, ["compressor_canonical_name"])

}.toMap
}
.groupBy(_._1)
.map {
case (group, elements) => (group, elements.flatMap(_._2))
}

val metadataBuilder = new MetadataBuilder()
annotations.foreach(annoData => metadataBuilder.putStringArray(annoData._1, annoData._2))
annotations.foreach {
annotationData =>
verifyAnnotation(annotationData._1, annotationData._2)
metadataBuilder.putStringArray(annotationData._1, annotationData._2.toArray)
}

StructField(field.name.toString, dataType, nullable, metadataBuilder.build())
}
Expand Down Expand Up @@ -87,4 +97,12 @@ object StructAnalyser extends Logging {
}
}

private[this] def verifyAnnotation(annotation: String, data: List[String]): Unit = {
if (annotation == COLUMN_NAME) {
require(data.length == 1, "There should not be more than one ColumnName annotation")
} else if (annotation == COMPRESS) {
require(data.length == 1, "There should not be more than one Compress annotation")
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,14 @@ object SparkRepository {
*/
private[repository] def handleConditions(conditions: Set[Condition], schema: StructType): Set[Condition] = {

val columnWithAlias = schema.filter(_.metadata.contains(ColumnName.toString()))
val binaryColumns = schema.filter(_.metadata.contains(classOf[Compress].getCanonicalName))
val columnWithAlias = schema.filter(_.metadata.contains(SchemaConverter.COLUMN_NAME))
val binaryColumns = schema.filter(_.metadata.contains(SchemaConverter.COMPRESS))


val binaryColumnNames = binaryColumns.map(_.name)
val aliasBinaryColumns = binaryColumns
.filter(bc => columnWithAlias.map(_.name).contains(bc.name))
.map(bc => bc.metadata.getStringArray(ColumnName.toString()).head)
.map(bc => bc.metadata.getStringArray(SchemaConverter.COLUMN_NAME).head)

conditions
.map {
Expand All @@ -236,7 +236,7 @@ object SparkRepository {
*/
columnWithAlias.foreach {
col =>
val alias = col.metadata.getStringArray(ColumnName.toString()).headOption
val alias = col.metadata.getStringArray(SchemaConverter.COLUMN_NAME).headOption
if (alias.nonEmpty) {
sqlString = sqlString.replace(s"`${col.name}`", s"`${alias.get}`")
}
Expand All @@ -254,7 +254,7 @@ object SparkRepository {
*/
columnWithAlias.find(_.name == cond.key) match {
case Some(a) =>
cond.copy(key = a.metadata.getStringArray(ColumnName.toString()).head)
cond.copy(key = a.metadata.getStringArray(SchemaConverter.COLUMN_NAME).head)
case _ => cond
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class SchemaConverterSuite extends AnyFunSuite {

val df = SchemaConverter.toDF(ds)
df.show()
assert(df.columns === Array("a", "b", "c", "_sort_key", "_primary_key"))
assert(df.columns.toSet === Set("a", "b", "c", "_sort_key", "_primary_key"))
assert(df.collect().map(_.getAs[String]("_primary_key")) === Array("a-1", "b-2", "c-3"))
assert(df.filter($"_primary_key" === "c-3").collect().length === 1)

Expand All @@ -144,6 +144,23 @@ class SchemaConverterSuite extends AnyFunSuite {

}

  test("[SETL-34] SchemaConverter should handle multi CompoundKeys on the same field") {
    val spark: SparkSession = new SparkSessionBuilder().setEnv("local").build().get()
    import spark.implicits._

    // col1 of MultipleCompoundKeyTest carries both a "sort" and a "part"
    // CompoundKey annotation, so it must contribute to both generated key columns.
    val ds = Seq(
      MultipleCompoundKeyTest("a", "1", "A"),
      MultipleCompoundKeyTest("b", "2", "B"),
      MultipleCompoundKeyTest("c", "3", "C")
    ).toDS()

    val df = SchemaConverter.toDF(ds)

    // Expect col3 renamed to COLUMN_3 (via @ColumnName) and one appended key
    // column per CompoundKey id: _part_key and _sort_key.
    assert(df.columns === Array("col1", "col2", "COLUMN_3", "_part_key", "_sort_key"))
    // _part_key = col1-col3 values joined by the key separator;
    // _sort_key = col1-col2 values joined by the key separator.
    assert(df.select($"_part_key".as[String]).collect() === Array("a-A", "b-B", "c-C"))
    assert(df.select($"_sort_key".as[String]).collect() === Array("a-1", "b-2", "c-3"))
  }

test("Schema converter should add missing nullable columns in the DF-DS conversion") {
val spark: SparkSession = new SparkSessionBuilder().setEnv("local").build().get()
import spark.implicits._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,20 @@ class StructAnalyserSuite extends AnyFunSuite {
val schema: StructType = StructAnalyser.analyseSchema[TestStructAnalyser]

test("StructAnalyser should be able to handle @ColumnName") {
val fields = schema.filter(_.metadata.contains(ColumnName.toString()))
val fields = schema.filter(_.metadata.contains(classOf[ColumnName].getCanonicalName))

assert(fields.length === 1)
assert(fields.head.name === "col1")
assert(fields.head.metadata.getStringArray(ColumnName.toString()) === Array("alias1"))
assert(fields.head.metadata.getStringArray(classOf[ColumnName].getCanonicalName) === Array("alias1"))

}

test("StructAnalyser should be able to handle @CompoundKey") {
val fields = schema.filter(_.metadata.contains(CompoundKey.toString()))
val fields = schema.filter(_.metadata.contains(classOf[CompoundKey].getCanonicalName))

assert(fields.length === 2)
assert(fields.map(_.name) === Array("col2", "col22"))
assert(fields.map(_.metadata.getStringArray(CompoundKey.toString())).map(_ (0)) === List("test", "test"))
assert(fields.map(_.metadata.getStringArray(CompoundKey.toString())).map(_ (1)) === List("1", "2"))

assert(fields.map(_.metadata.getStringArray(classOf[CompoundKey].getCanonicalName)).map(_ (0)) === List("test!@1", "test!@2"))
}

test("StructAnalyser should be able to handle @Compress") {
Expand All @@ -50,4 +48,19 @@ class StructAnalyserSuite extends AnyFunSuite {
)
}

  test("[SETL-34] StructAnalyser should handle multiple @CompoundKey annotations") {
    val structType = StructAnalyser.analyseSchema[TestClasses.MultipleCompoundKeyTest]
    // Debug output of the analysed schema; each field's metadata should carry
    // one serialized entry per CompoundKey annotation.
    structType.foreach { x =>
      println(s"name: ${x.name}, type: ${x.dataType}, meta: ${x.metadata}")
    }

    // Each CompoundKey annotation is serialized as "id!@position"; a field
    // annotated twice (col1) stores two entries in its metadata string array.
    assert(structType.find(_.name == "col1").get.metadata.getStringArray(classOf[CompoundKey].getCanonicalName) === Array("sort!@1","part!@1"))
    assert(structType.find(_.name == "col2").get.metadata.getStringArray(classOf[CompoundKey].getCanonicalName) === Array("sort!@2"))
    assert(structType.find(_.name == "col3").get.metadata.getStringArray(classOf[CompoundKey].getCanonicalName) === Array("part!@2"))
  }

  test("[SETL-34] StructAnalyser should throw exception when there are more than one ColumnName annotation") {
    // WrongClass declares two @ColumnName annotations on the same field;
    // the analyser's annotation verification rejects this with require(...).
    assertThrows[IllegalArgumentException](StructAnalyser.analyseSchema[TestClasses.WrongClass])
  }

}
6 changes: 6 additions & 0 deletions src/test/scala/com/jcdecaux/setl/internal/TestClasses.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ import com.jcdecaux.setl.transformation.Factory

object TestClasses {

  // Invalid fixture: a field may carry at most one @ColumnName annotation,
  // so analysing this class must fail (see StructAnalyserSuite).
  case class WrongClass(@ColumnName("1") @ColumnName("2") col1: String)

  // Fixture with multiple @CompoundKey annotations on one field: col1 belongs
  // to both the "sort" key (position 1) and the "part" key (position 1);
  // col3 additionally gets the alias COLUMN_3 via @ColumnName.
  case class MultipleCompoundKeyTest(@CompoundKey("sort", "1") @CompoundKey("part", "1") col1: String,
                                     @CompoundKey("sort", "2") col2: String,
                                     @CompoundKey("part", "2") @ColumnName("COLUMN_3") col3: String)

case class InnerClass(innerCol1: String, innerCol2: String)

case class TestCompression(@ColumnName("dqsf") col1: String,
Expand Down