From 45b4af378f2e1eda83193a809df5291cd0420876 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:04:28 -0800 Subject: [PATCH 01/37] Changes --- .../spark/utils/DataFramePrinter.scala | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) create mode 100644 spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala new file mode 100644 index 0000000000..764b3fc1e3 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala @@ -0,0 +1,144 @@ +package ai.chronon.spark.utils + +import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.DataFrame + +import java.util.logging.Logger + + +// there was no way to print a message, and the contents of the dataframe together +// the methods to convert a dataframe into a string were private inside spark +// so pulling it out +object DataFramePrinter { + private val fullWidthRegex = ("""[""" + + // scalastyle:off nonascii + "\u1100-\u115F" + + "\u2E80-\uA4CF" + + "\uAC00-\uD7A3" + + "\uF900-\uFAFF" + + "\uFE10-\uFE19" + + "\uFE30-\uFE6F" + + "\uFF00-\uFF60" + + "\uFFE0-\uFFE6" + + // scalastyle:on nonascii + """]""").r + + private def stringHalfWidth(str: String): Int = { + if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size + } + + private def escapeMetaCharacters(str: String): String = { + str.replaceAll("\n", "\\\\n") + .replaceAll("\r", "\\\\r") + .replaceAll("\t", "\\\\t") + .replaceAll("\f", "\\\\f") + .replaceAll("\b", "\\\\b") + .replaceAll("\u000B", "\\\\v") + .replaceAll("\u0007", "\\\\a") + } + + def showString( df: DataFrame, + numRows: Int = 10, + truncate: Int = 20, + vertical: Boolean = false): String = { + val data = df.take(numRows + 1) + + // For array values, replace Seq and Array with square brackets + // For cells that are beyond `truncate` characters, replace it with the + // first `truncate-3` and "..." + val tmpRows = df.schema.fieldNames.map(escapeMetaCharacters).toSeq +: data.map { row => + row.toSeq.map { cell => + assert(cell != null, "ToPrettyString is not nullable and should not return null value") + // Escapes meta-characters not to break the `showString` format + val str = escapeMetaCharacters(cell.toString) + if (truncate > 0 && str.length > truncate) { + // do not show ellipses for strings shorter than 4 characters. + if (truncate < 4) str.substring(0, truncate) + else str.substring(0, truncate - 3) + "..." 
+ } else { + str + } + }: Seq[String] + } + + val hasMoreData = tmpRows.length - 1 > numRows + val rows = tmpRows.take(numRows + 1) + + val sb = new StringBuilder + val numCols = df.schema.fieldNames.length + // We set a minimum column width at '3' + val minimumColWidth = 3 + + if (!vertical) { + // Initialise the width of each column to a minimum value + val colWidths = Array.fill(numCols)(minimumColWidth) + + // Compute the width of each column + for (row <- rows) { + for ((cell, i) <- row.zipWithIndex) { + colWidths(i) = math.max(colWidths(i), stringHalfWidth(cell)) + } + } + + val paddedRows = rows.map { row => + row.zipWithIndex.map { case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } else { + StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } + } + } + + // Create SeparateLine + val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() + + // column names + paddedRows.head.addString(sb, "|", "|", "|\n") + sb.append(sep) + + // data + paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) + sb.append(sep) + } else { + // Extended display mode enabled + val fieldNames = rows.head + val dataRows = rows.tail + + // Compute the width of field name and data columns + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => + math.max(curMax, stringHalfWidth(fieldName)) + } + val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => + math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) + } + + dataRows.zipWithIndex.foreach { case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad( + s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex.map { case (cell, j) => + val fieldName = StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, + dataColWidth - stringHalfWidth(cell) + cell.length) + s" $fieldName | $data " + }.addString(sb, "", "\n", "\n") + } + } + + // Print a footer + if (vertical && rows.tail.isEmpty) { + // In a vertical mode, print an empty row set explicitly + sb.append("(0 rows)\n") + } else if (hasMoreData) { + // For Data that has more than "numRows" records + val rowsString = if (numRows == 1) "row" else "rows" + sb.append(s"only showing top $numRows $rowsString\n") + } + + sb.toString() + } + +} From e3d21152e7f0dfa9ac3f50494c6d1b68936d09b1 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:06:46 -0800 Subject: [PATCH 02/37] changes so far --- .../spark/utils/DataFramePrinter.scala | 61 ++++++++++--------- 1 file changed, 31 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala index 764b3fc1e3..cc7ce60de3 100644 --- a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala +++ b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala @@ -3,9 +3,6 @@ package ai.chronon.spark.utils import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.DataFrame -import java.util.logging.Logger - - // there was no way to print a message, and the contents of the dataframe together // the methods to convert a dataframe into a string were private inside spark // so pulling it out @@ -28,7 +25,8 
@@ object DataFramePrinter { } private def escapeMetaCharacters(str: String): String = { - str.replaceAll("\n", "\\\\n") + str + .replaceAll("\n", "\\\\n") .replaceAll("\r", "\\\\r") .replaceAll("\t", "\\\\t") .replaceAll("\f", "\\\\f") @@ -37,10 +35,7 @@ object DataFramePrinter { .replaceAll("\u0007", "\\\\a") } - def showString( df: DataFrame, - numRows: Int = 10, - truncate: Int = 20, - vertical: Boolean = false): String = { + def showString(df: DataFrame, numRows: Int = 10, truncate: Int = 20, vertical: Boolean = false): String = { val data = df.take(numRows + 1) // For array values, replace Seq and Array with square brackets @@ -81,12 +76,13 @@ object DataFramePrinter { } val paddedRows = rows.map { row => - row.zipWithIndex.map { case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } else { - StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } + row.zipWithIndex.map { + case (cell, i) => + if (truncate > 0) { + StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } else { + StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) + } } } @@ -106,25 +102,30 @@ object DataFramePrinter { val dataRows = rows.tail // Compute the width of field name and data columns - val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { case (curMax, fieldName) => - math.max(curMax, stringHalfWidth(fieldName)) + val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { + case (curMax, fieldName) => + math.max(curMax, stringHalfWidth(fieldName)) } - val dataColWidth = dataRows.foldLeft(minimumColWidth) { case (curMax, row) => - math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) + val dataColWidth = dataRows.foldLeft(minimumColWidth) { + case (curMax, row) => + math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) } - dataRows.zipWithIndex.foreach { case (row, i) => - // "+ 5" in size means a character length except for padded names and data - val rowHeader = StringUtils.rightPad( - s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") - sb.append(rowHeader).append("\n") - row.zipWithIndex.map { case (cell, j) => - val fieldName = StringUtils.rightPad(fieldNames(j), - fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) - val data = StringUtils.rightPad(cell, - dataColWidth - stringHalfWidth(cell) + cell.length) - s" $fieldName | $data " - }.addString(sb, "", "\n", "\n") + dataRows.zipWithIndex.foreach { + case (row, i) => + // "+ 5" in size means a character length except for padded names and data + val rowHeader = StringUtils.rightPad(s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") + sb.append(rowHeader).append("\n") + row.zipWithIndex + .map { + case (cell, j) => + val fieldName = + StringUtils.rightPad(fieldNames(j), + fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) + val data = StringUtils.rightPad(cell, dataColWidth - stringHalfWidth(cell) + cell.length) + s" $fieldName | $data " + } + .addString(sb, "", "\n", "\n") } } From c0e5d8fabc3082001d2616b5d73b5ec0499bc827 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:21:56 -0800 Subject: [PATCH 03/37] remove unused file --- .../spark/utils/DataFramePrinter.scala | 145 ------------------ 1 file changed, 145 deletions(-) delete mode 100644 spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala diff --git a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala 
b/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala deleted file mode 100644 index cc7ce60de3..0000000000 --- a/spark/src/main/scala/ai/chronon/spark/utils/DataFramePrinter.scala +++ /dev/null @@ -1,145 +0,0 @@ -package ai.chronon.spark.utils - -import org.apache.commons.lang3.StringUtils -import org.apache.spark.sql.DataFrame - -// there was no way to print a message, and the contents of the dataframe together -// the methods to convert a dataframe into a string were private inside spark -// so pulling it out -object DataFramePrinter { - private val fullWidthRegex = ("""[""" + - // scalastyle:off nonascii - "\u1100-\u115F" + - "\u2E80-\uA4CF" + - "\uAC00-\uD7A3" + - "\uF900-\uFAFF" + - "\uFE10-\uFE19" + - "\uFE30-\uFE6F" + - "\uFF00-\uFF60" + - "\uFFE0-\uFFE6" + - // scalastyle:on nonascii - """]""").r - - private def stringHalfWidth(str: String): Int = { - if (str == null) 0 else str.length + fullWidthRegex.findAllIn(str).size - } - - private def escapeMetaCharacters(str: String): String = { - str - .replaceAll("\n", "\\\\n") - .replaceAll("\r", "\\\\r") - .replaceAll("\t", "\\\\t") - .replaceAll("\f", "\\\\f") - .replaceAll("\b", "\\\\b") - .replaceAll("\u000B", "\\\\v") - .replaceAll("\u0007", "\\\\a") - } - - def showString(df: DataFrame, numRows: Int = 10, truncate: Int = 20, vertical: Boolean = false): String = { - val data = df.take(numRows + 1) - - // For array values, replace Seq and Array with square brackets - // For cells that are beyond `truncate` characters, replace it with the - // first `truncate-3` and "..." - val tmpRows = df.schema.fieldNames.map(escapeMetaCharacters).toSeq +: data.map { row => - row.toSeq.map { cell => - assert(cell != null, "ToPrettyString is not nullable and should not return null value") - // Escapes meta-characters not to break the `showString` format - val str = escapeMetaCharacters(cell.toString) - if (truncate > 0 && str.length > truncate) { - // do not show ellipses for strings shorter than 4 characters. - if (truncate < 4) str.substring(0, truncate) - else str.substring(0, truncate - 3) + "..." 
- } else { - str - } - }: Seq[String] - } - - val hasMoreData = tmpRows.length - 1 > numRows - val rows = tmpRows.take(numRows + 1) - - val sb = new StringBuilder - val numCols = df.schema.fieldNames.length - // We set a minimum column width at '3' - val minimumColWidth = 3 - - if (!vertical) { - // Initialise the width of each column to a minimum value - val colWidths = Array.fill(numCols)(minimumColWidth) - - // Compute the width of each column - for (row <- rows) { - for ((cell, i) <- row.zipWithIndex) { - colWidths(i) = math.max(colWidths(i), stringHalfWidth(cell)) - } - } - - val paddedRows = rows.map { row => - row.zipWithIndex.map { - case (cell, i) => - if (truncate > 0) { - StringUtils.leftPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } else { - StringUtils.rightPad(cell, colWidths(i) - stringHalfWidth(cell) + cell.length) - } - } - } - - // Create SeparateLine - val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString() - - // column names - paddedRows.head.addString(sb, "|", "|", "|\n") - sb.append(sep) - - // data - paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n")) - sb.append(sep) - } else { - // Extended display mode enabled - val fieldNames = rows.head - val dataRows = rows.tail - - // Compute the width of field name and data columns - val fieldNameColWidth = fieldNames.foldLeft(minimumColWidth) { - case (curMax, fieldName) => - math.max(curMax, stringHalfWidth(fieldName)) - } - val dataColWidth = dataRows.foldLeft(minimumColWidth) { - case (curMax, row) => - math.max(curMax, row.map(cell => stringHalfWidth(cell)).max) - } - - dataRows.zipWithIndex.foreach { - case (row, i) => - // "+ 5" in size means a character length except for padded names and data - val rowHeader = StringUtils.rightPad(s"-RECORD $i", fieldNameColWidth + dataColWidth + 5, "-") - sb.append(rowHeader).append("\n") - row.zipWithIndex - .map { - case (cell, j) => - val fieldName = - StringUtils.rightPad(fieldNames(j), - fieldNameColWidth - stringHalfWidth(fieldNames(j)) + fieldNames(j).length) - val data = StringUtils.rightPad(cell, dataColWidth - stringHalfWidth(cell) + cell.length) - s" $fieldName | $data " - } - .addString(sb, "", "\n", "\n") - } - } - - // Print a footer - if (vertical && rows.tail.isEmpty) { - // In a vertical mode, print an empty row set explicitly - sb.append("(0 rows)\n") - } else if (hasMoreData) { - // For Data that has more than "numRows" records - val rowsString = if (numRows == 1) "row" else "rows" - sb.append(s"only showing top $numRows $rowsString\n") - } - - sb.toString() - } - -} From e718f91e8a93ff5962ce669c5dd25be7b44496ea Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:40:25 -0800 Subject: [PATCH 04/37] fix --- .../scala/ai/chronon/api/ColorPrinter.scala | 24 ------------------- 1 file changed, 24 deletions(-) delete mode 100644 api/src/main/scala/ai/chronon/api/ColorPrinter.scala diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala deleted file mode 100644 index 4d1dc57c50..0000000000 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ /dev/null @@ -1,24 +0,0 @@ -package ai.chronon.api - -object ColorPrinter { - // ANSI escape codes for text colors - private val ANSI_RESET = "\u001B[0m" - - // Colors chosen for visibility on both dark and light backgrounds - // More muted colors that should still be visible on various backgrounds - private val ANSI_RED = "\u001B[38;5;131m" // Muted red (soft burgundy) - private val ANSI_BLUE 
= "\u001B[38;5;32m" // Medium blue - private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange - private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green - - private val BOLD = "\u001B[1m" - - implicit class ColorString(val s: String) extends AnyVal { - def red: String = s"$ANSI_RED$s$ANSI_RESET" - def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" - def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" - def green: String = s"$ANSI_GREEN$s$ANSI_RESET" - def low: String = s.toLowerCase - def highlight: String = s"$BOLD$ANSI_RED$s$ANSI_RESET" - } -} From d3ff253b7a78ed922f9028c5a1c768390ba187e8 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 14:46:02 -0800 Subject: [PATCH 05/37] adding back color printer --- .../scala/ai/chronon/api/ColorPrinter.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 api/src/main/scala/ai/chronon/api/ColorPrinter.scala diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala new file mode 100644 index 0000000000..bf44fa2d13 --- /dev/null +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -0,0 +1,21 @@ +package ai.chronon.api + +object ColorPrinter { + // ANSI escape codes for text colors + private val ANSI_RESET = "\u001B[0m" + + // Colors chosen for visibility on both dark and light backgrounds + // More muted colors that should still be visible on various backgrounds + private val ANSI_RED = "\u001B[38;5;131m" // Muted red (soft burgundy) + private val ANSI_BLUE = "\u001B[38;5;32m" // Medium blue + private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange + private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green + + implicit class ColorString(val s: String) extends AnyVal { + def red: String = s"$ANSI_RED$s$ANSI_RESET" + def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" + def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" + def green: String = s"$ANSI_GREEN$s$ANSI_RESET" + def low: String = s.toLowerCase + } +} \ No newline at end of file From bbf1ddd201983b774128d4d8c03f01a1534ae962 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 3 Nov 2024 17:39:58 -0800 Subject: [PATCH 06/37] scalafmt fix --- api/src/main/scala/ai/chronon/api/ColorPrinter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala index bf44fa2d13..e779e3eaf1 100644 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -18,4 +18,4 @@ object ColorPrinter { def green: String = s"$ANSI_GREEN$s$ANSI_RESET" def low: String = s.toLowerCase } -} \ No newline at end of file +} From d61c83a486f69276402308e3b51d6ce85d5415b3 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 18:31:57 -0800 Subject: [PATCH 07/37] assign intervals --- .../online/stats/DistanceMetrics.scala | 147 ++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala new file mode 100644 index 0000000000..0620f73aa0 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -0,0 +1,147 @@ +package ai.chronon.online.stats + +import ai.chronon.api.ColorPrinter.ColorString +import ai.chronon.api.Window + +import scala.math._ + + +object DistanceMetrics { + 
+ // TODO move this to unit test + def main(args: Array[String]): Unit = { + val A = Array(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200).map( + _.toDouble) + val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map( + _.toDouble) + + val jsd = jensenShannonDivergence(A, B) + println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f") + val psi = populationStabilityIndex(A, B) + println(f"The Population Stability Index between distributions A and B is: $psi%.5f") + val hd = hellingerDistance(A, B) + println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") + + + // format: off + // aligned vertically for easier reasoning + val ptiles = Array( 1, 4, 6,6,6, 8, 9 ) + val breaks = Array(0, 1, 2, 3, 5, 6, 7, 8, 9, 10) + // format: on + + //val interval = 0.25 + val expected = Array(0.0, 1.0/3.0 , 1.0/3.0, (1.0)/(3.0) + (1.0)/(2.0), (1.0)/(2.0), 2.5, 0.5, 1, 0) + + val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble)) + + expected.zip(result).foreach{case (e, r) => println(s"exp: $e res: $r")} + + } + + // all same size - used for distance computation and for drill down display in front-end + case class Distributions(p: Array[Double], q: Array[Double], bins: Array[String]) + + case class Comparison[T](previous: T, current: T, timeDelta: Window) + + + def functionBuilder[T](binningFunc: Comparison[T] => Distributions, distanceFunc: Distributions => Double): Comparison[T] => Double = { + c => + val dists = binningFunc(c) + val distance = distanceFunc(dists) + distance + } + + + def hellingerDistance(p: Array[Double], q: Array[Double]): Double = { + val (pProbs, qProbs) = computePDFs(p, q) + + sqrt( + pProbs + .zip(qProbs) + .map { + case (pi, qi) => + pow(sqrt(pi) - sqrt(qi), 2) + } + .sum / 2) + } + + def populationStabilityIndex(p: Array[Double], q: Array[Double]): Double = { + val (pProbs, qProbs) = computePDFs(p, q) + + pProbs + .zip(qProbs) + .map { + case (pi, qi) => + if (pi > 0 && qi > 0) (qi - pi) * log(qi / pi) + else 0.0 // Handle zero probabilities + } + .sum + } + + def jensenShannonDivergence(p: Array[Double], q: Array[Double]): Double = { + // Step 1: compute probability distributions on the same x-axis + val (pdfP, pdfQ) = computePDFs(p, q) + + // Step 2: compute the mixture distribution M + val pdfM = pdfP.zip(pdfQ).map { case (a, b) => 0.5 * (a + b) } + + // Step 3: compute divergence + val klAM = klDivergence(pdfP, pdfM) + val klBM = klDivergence(pdfQ, pdfM) + + 0.5 * (klAM + klBM) + } + + def computePDFs(p: Array[Double], q: Array[Double]): (Array[Double], Array[Double]) = { + val breakpoints = (p ++ q).distinct.sorted + + val pdfP = computePDF(p, breakpoints).map(_.value) + val pdfQ = computePDF(q, breakpoints).map(_.value) + + pdfP -> pdfQ + } + + case class Mass(value: Double, isPointMass: Boolean) + + def computePDF(percentiles: Array[Double], breaks: Array[Double]): Array[Mass] = { + val n = percentiles.length + require(percentiles.length > 2, "Need at-least 3 percentiles to plot a distribution") + + val interval: Double = 1.toDouble / (n - 1.0) + + def mass(i: Int, eh: Int): Mass = { + def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i-1)) + val isPointMass = eh > 1 + val m = (i, eh) match { + case (0, _) => 0.0 // before range + case (x, 0) if x>=n => 0.0 // after range + case (_, e) if e > 1 => (e - 1) * interval // point mass + case (x, 1) if x==n => 
indexMass(n-1) // exactly at end of range + case (x, _) => indexMass(x) // somewhere in between + } + Mass(m, isPointMass) + } + + var i = 0 + breaks.map { break => + var equalityHits = 0 + while (i < percentiles.length && percentiles(i) <= break) { + if (percentiles(i) == break) equalityHits += 1 + i += 1 + } + mass(i, equalityHits) + } + } + + def klDivergence(p: Array[Double], q: Array[Double]): Double = { + require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") + var i = 0 + var result = 0.0 + while(i < p.length) { + val inc = if (p(i)> 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 + result += inc + i += 1 + } + result + } +} From ce394283beaa95ece1564fe7f06c7ca74f454bf3 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 18:33:02 -0800 Subject: [PATCH 08/37] assign intervals --- .../online/stats/DistanceMetrics.scala | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index 0620f73aa0..a7896efc94 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,11 +1,8 @@ package ai.chronon.online.stats - -import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Window import scala.math._ - object DistanceMetrics { // TODO move this to unit test @@ -22,7 +19,6 @@ object DistanceMetrics { val hd = hellingerDistance(A, B) println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") - // format: off // aligned vertically for easier reasoning val ptiles = Array( 1, 4, 6,6,6, 8, 9 ) @@ -30,11 +26,11 @@ object DistanceMetrics { // format: on //val interval = 0.25 - val expected = Array(0.0, 1.0/3.0 , 1.0/3.0, (1.0)/(3.0) + (1.0)/(2.0), (1.0)/(2.0), 2.5, 0.5, 1, 0) + val expected = Array(0.0, 1.0 / 3.0, 1.0 / 3.0, (1.0) / (3.0) + (1.0) / (2.0), (1.0) / (2.0), 2.5, 0.5, 1, 0) val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble)) - expected.zip(result).foreach{case (e, r) => println(s"exp: $e res: $r")} + expected.zip(result).foreach { case (e, r) => println(s"exp: $e res: $r") } } @@ -43,15 +39,13 @@ object DistanceMetrics { case class Comparison[T](previous: T, current: T, timeDelta: Window) - - def functionBuilder[T](binningFunc: Comparison[T] => Distributions, distanceFunc: Distributions => Double): Comparison[T] => Double = { - c => - val dists = binningFunc(c) - val distance = distanceFunc(dists) - distance + def functionBuilder[T](binningFunc: Comparison[T] => Distributions, + distanceFunc: Distributions => Double): Comparison[T] => Double = { c => + val dists = binningFunc(c) + val distance = distanceFunc(dists) + distance } - def hellingerDistance(p: Array[Double], q: Array[Double]): Double = { val (pProbs, qProbs) = computePDFs(p, q) @@ -110,14 +104,14 @@ object DistanceMetrics { val interval: Double = 1.toDouble / (n - 1.0) def mass(i: Int, eh: Int): Mass = { - def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i-1)) + def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i - 1)) val isPointMass = eh > 1 val m = (i, eh) match { - case (0, _) => 0.0 // before range - case (x, 0) if x>=n => 0.0 // after range - case (_, e) if e > 1 => (e - 1) * interval // point mass - case (x, 1) if x==n => indexMass(n-1) // exactly at end of range - case (x, _) => 
indexMass(x) // somewhere in between + case (0, _) => 0.0 // before range + case (x, 0) if x >= n => 0.0 // after range + case (_, e) if e > 1 => (e - 1) * interval // point mass + case (x, 1) if x == n => indexMass(n - 1) // exactly at end of range + case (x, _) => indexMass(x) // somewhere in between } Mass(m, isPointMass) } @@ -137,8 +131,8 @@ object DistanceMetrics { require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") var i = 0 var result = 0.0 - while(i < p.length) { - val inc = if (p(i)> 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 + while (i < p.length) { + val inc = if (p(i) > 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 result += inc i += 1 } From 787b66416870b3fc1b558602b4550697d83b17b1 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Wed, 6 Nov 2024 23:12:52 -0800 Subject: [PATCH 09/37] tile summary distance --- .../online/stats/DistanceMetrics.scala | 192 +++++++++--------- 1 file changed, 91 insertions(+), 101 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index a7896efc94..e4820172e4 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,7 +1,9 @@ package ai.chronon.online.stats -import ai.chronon.api.Window +import ai.chronon.api.DriftMetric +import ai.chronon.api.TileSummaries import scala.math._ +import scala.util.ScalaJavaConversions.IteratorOps object DistanceMetrics { @@ -12,130 +14,118 @@ object DistanceMetrics { val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map( _.toDouble) - val jsd = jensenShannonDivergence(A, B) + val jsd = percentileDistance(A, B, DriftMetric.JENSEN_SHANNON) println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f") - val psi = populationStabilityIndex(A, B) + val psi = percentileDistance(A, B, DriftMetric.PSI) println(f"The Population Stability Index between distributions A and B is: $psi%.5f") - val hd = hellingerDistance(A, B) + val hd = percentileDistance(A, B, DriftMetric.HELLINGER) println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") - - // format: off - // aligned vertically for easier reasoning - val ptiles = Array( 1, 4, 6,6,6, 8, 9 ) - val breaks = Array(0, 1, 2, 3, 5, 6, 7, 8, 9, 10) - // format: on - - //val interval = 0.25 - val expected = Array(0.0, 1.0 / 3.0, 1.0 / 3.0, (1.0) / (3.0) + (1.0) / (2.0), (1.0) / (2.0), 2.5, 0.5, 1, 0) - - val result = AssignIntervals.on(ptiles = ptiles.map(_.toDouble), breaks = breaks.map(_.toDouble)) - - expected.zip(result).foreach { case (e, r) => println(s"exp: $e res: $r") } - } - // all same size - used for distance computation and for drill down display in front-end - case class Distributions(p: Array[Double], q: Array[Double], bins: Array[String]) - - case class Comparison[T](previous: T, current: T, timeDelta: Window) - - def functionBuilder[T](binningFunc: Comparison[T] => Distributions, - distanceFunc: Distributions => Double): Comparison[T] => Double = { c => - val dists = binningFunc(c) - val distance = distanceFunc(dists) - distance + @inline + private def toArray(l: java.util.List[java.lang.Double]): Array[Double] = { + l.iterator().toScala.map(_.toDouble).toArray } - def hellingerDistance(p: Array[Double], q: Array[Double]): Double = { - val (pProbs, qProbs) = computePDFs(p, q) - - sqrt( - pProbs - 
.zip(qProbs) - .map { - case (pi, qi) => - pow(sqrt(pi) - sqrt(qi), 2) - } - .sum / 2) + @inline + private def normalizeInplace(arr: Array[Double]): Array[Double] = { + val sum = arr.sum + var i = 0 + while (i < arr.length) { + arr.update(i, arr(i) / sum) + i += 1 + } + arr } - def populationStabilityIndex(p: Array[Double], q: Array[Double]): Double = { - val (pProbs, qProbs) = computePDFs(p, q) - - pProbs - .zip(qProbs) - .map { - case (pi, qi) => - if (pi > 0 && qi > 0) (qi - pi) * log(qi / pi) - else 0.0 // Handle zero probabilities - } - .sum + def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { + require(a.isSetPercentiles == b.isSetPercentiles, "Percentiles should be either set or unset together") + require(a.isSetHistogram == b.isSetHistogram, "Histograms should be either set or unset together") + + val isContinuous = a.isSetPercentiles && b.isSetPercentiles + val isCategorical = a.isSetHistogram && b.isSetHistogram + if (isContinuous) + percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) + else if (isCategorical) + categoricalDistance(a.getHistogram, b.getHistogram, metric) + else + null } - def jensenShannonDivergence(p: Array[Double], q: Array[Double]): Double = { - // Step 1: compute probability distributions on the same x-axis - val (pdfP, pdfQ) = computePDFs(p, q) + def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric): Double = { + val breaks = (a ++ b).sorted.distinct + val aProjected = AssignIntervals.on(a, breaks) + val bProjected = AssignIntervals.on(b, breaks) - // Step 2: compute the mixture distribution M - val pdfM = pdfP.zip(pdfQ).map { case (a, b) => 0.5 * (a + b) } + val aNormalized = normalizeInplace(aProjected) + val bNormalized = normalizeInplace(bProjected) - // Step 3: compute divergence - val klAM = klDivergence(pdfP, pdfM) - val klBM = klDivergence(pdfQ, pdfM) + val func = termFunc(metric) - 0.5 * (klAM + klBM) + var i = 0 + var result = 0.0 + while (i < aNormalized.length) { + result += func(aNormalized(i), bNormalized(i)) + i += 1 + } + result } - def computePDFs(p: Array[Double], q: Array[Double]): (Array[Double], Array[Double]) = { - val breakpoints = (p ++ q).distinct.sorted + type Histogram = java.util.Map[String, java.lang.Long] + def categoricalDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + val aIt = a.entrySet().iterator() + var result = 0.0 + val func = termFunc(metric) + while (aIt.hasNext) { + val entry = aIt.next() + val key = entry.getKey + val aVal = entry.getValue.toDouble + val bValOpt = b.get(key) + val bVal = if (bValOpt == null) bValOpt.toDouble else 0.0 + val term = func(aVal, bVal) + result += term + } - val pdfP = computePDF(p, breakpoints).map(_.value) - val pdfQ = computePDF(q, breakpoints).map(_.value) + val bIt = b.entrySet().iterator() + while (bIt.hasNext) { + val entry = bIt.next() + val key = entry.getKey + val bVal = entry.getValue.toDouble + val aValOpt = a.get(key) + if (aValOpt == null) { + result += func(0.0, bVal) + } + } - pdfP -> pdfQ + result } - case class Mass(value: Double, isPointMass: Boolean) - - def computePDF(percentiles: Array[Double], breaks: Array[Double]): Array[Mass] = { - val n = percentiles.length - require(percentiles.length > 2, "Need at-least 3 percentiles to plot a distribution") + @inline + def klDivergenceTerm(a: Double, b: Double): Double = { + if (a > 0 && b > 0) a * math.log(a / b) else 0 + } - val interval: Double = 1.toDouble / (n - 1.0) + @inline + def jsdTerm(a: Double, b: 
Double): Double = { + val m = (a + b) * 0.5 + (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 + } - def mass(i: Int, eh: Int): Mass = { - def indexMass(i: Int): Double = interval / (percentiles(i) - percentiles(i - 1)) - val isPointMass = eh > 1 - val m = (i, eh) match { - case (0, _) => 0.0 // before range - case (x, 0) if x >= n => 0.0 // after range - case (_, e) if e > 1 => (e - 1) * interval // point mass - case (x, 1) if x == n => indexMass(n - 1) // exactly at end of range - case (x, _) => indexMass(x) // somewhere in between - } - Mass(m, isPointMass) - } + @inline + def hellingerTerm(a: Double, b: Double): Double = { + pow(sqrt(a) - sqrt(b), 2) * 0.5 + } - var i = 0 - breaks.map { break => - var equalityHits = 0 - while (i < percentiles.length && percentiles(i) <= break) { - if (percentiles(i) == break) equalityHits += 1 - i += 1 - } - mass(i, equalityHits) - } + @inline + def psiTerm(a: Double, b: Double): Double = { + if (a > 0 && b > 0) (b - a) * log(b / a) else 0.0 } - def klDivergence(p: Array[Double], q: Array[Double]): Double = { - require(p.length == q.length, s"Inputs are of different length ${p.length}, ${q.length}") - var i = 0 - var result = 0.0 - while (i < p.length) { - val inc = if (p(i) > 0 && q(i) > 0) p(i) * math.log(p(i) / q(i)) else 0 - result += inc - i += 1 + @inline + def termFunc(d: DriftMetric): (Double, Double) => Double = + d match { + case DriftMetric.PSI => psiTerm + case DriftMetric.HELLINGER => hellingerTerm + case DriftMetric.JENSEN_SHANNON => jsdTerm } - result - } } From e576dadbe8057d327e81675039883e5474c0214b Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Thu, 7 Nov 2024 18:54:19 -0800 Subject: [PATCH 10/37] histogram drift --- .../online/stats/DistanceMetrics.scala | 160 +++++++++++++++--- 1 file changed, 135 insertions(+), 25 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index e4820172e4..a586b95aab 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -4,22 +4,90 @@ import ai.chronon.api.TileSummaries import scala.math._ import scala.util.ScalaJavaConversions.IteratorOps +import scala.util.ScalaJavaConversions.JMapOps object DistanceMetrics { // TODO move this to unit test def main(args: Array[String]): Unit = { - val A = Array(0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200).map( - _.toDouble) - val B = Array(5, 15, 25, 35, 45, 55, 65, 75, 85, 95, 115, 115, 115, 135, 145, 155, 165, 1175, 1205, 1205, 1205).map( - _.toDouble) - - val jsd = percentileDistance(A, B, DriftMetric.JENSEN_SHANNON) - println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f") - val psi = percentileDistance(A, B, DriftMetric.PSI) - println(f"The Population Stability Index between distributions A and B is: $psi%.5f") - val hd = percentileDistance(A, B, DriftMetric.HELLINGER) - println(f"The Hellinger Distance between distributions A and B is: $hd%.5f") + + def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { + val stdDev = math.sqrt(variance) + + // Create probability points from 0.01 to 0.99 instead of 0 to 1 + val probPoints = (0 to breaks).map { i => + if (i == 0) 0.01 // p1 instead of p0 + else if (i == breaks) 0.99 // p99 instead of p100 + else i.toDouble / breaks + }.toArray + + // Convert probability points to percentiles + 
probPoints.map { p => + val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) + mean + (stdDev * standardNormalPercentile) + } + } + + // Helper function to calculate inverse error function + def inverseErf(x: Double): Double = { + // Approximation of inverse error function + // This is a rational approximation giving a maximum relative error of 3e-7 + val a = 0.147 + val signX = if (x >= 0) 1 else -1 + val absX = math.abs(x) + + val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) + val term2 = math.log(1 - absX * absX) / a + + signX * math.sqrt(term1 - term2) + } + + def compareDistributions(meanA: Double, + varianceA: Double, + meanB: Double, + varianceB: Double, + breaks: Int = 20, + debug: Boolean = false): Unit = { + + val aPercentiles = buildPercentiles(meanA, varianceA, breaks) + val bPercentiles = buildPercentiles(meanB, varianceB, breaks) + + val aHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val bHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val jsd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.JENSEN_SHANNON, debug = debug) + val jsdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.JENSEN_SHANNON) + println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f, $jsdHist%.5f") + + val psi = percentileDistance(aPercentiles, bPercentiles, DriftMetric.PSI, debug = debug) + val psiHist = histogramDistance(aHistogram, bHistogram, DriftMetric.PSI) + println(f"The Population Stability Index between distributions A and B is: $psi%.5f, $psiHist%.5f") + + val hd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.HELLINGER, debug = debug) + val hdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.HELLINGER) + println(f"The Hellinger Distance between distributions A and B is: $hd%.5f, $hdHist%.5f") + + println() + } + + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 205.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 305.0, varianceB = 256.0) + compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) + } @inline @@ -28,14 +96,16 @@ object DistanceMetrics { } @inline - private def normalizeInplace(arr: Array[Double]): Array[Double] = { + private def normalize(arr: Array[Double]): Array[Double] = { + // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot + val result = Array.ofDim[Double](arr.length) val sum = arr.sum var i = 0 while (i < arr.length) { - arr.update(i, arr(i) / sum) + result.update(i, arr(i) / sum) i += 1 } - arr + result } def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { @@ -47,32 +117,69 @@ object DistanceMetrics { if (isContinuous) percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) else if (isCategorical) - categoricalDistance(a.getHistogram, b.getHistogram, metric) + histogramDistance(a.getHistogram, b.getHistogram, metric) else null } - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric): Double = { + def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): Double = { val 
breaks = (a ++ b).sorted.distinct val aProjected = AssignIntervals.on(a, breaks) val bProjected = AssignIntervals.on(b, breaks) - val aNormalized = normalizeInplace(aProjected) - val bNormalized = normalizeInplace(bProjected) + val aNormalized = normalize(aProjected) + val bNormalized = normalize(bProjected) val func = termFunc(metric) var i = 0 var result = 0.0 + + // debug only, remove before merging + val deltas = Array.ofDim[Double](aNormalized.length) + while (i < aNormalized.length) { - result += func(aNormalized(i), bNormalized(i)) + val ai = aNormalized(i) + val bi = bNormalized(i) + val delta = func(ai, bi) + + // debug only remove before merging + deltas.update(i, delta) + + result += delta i += 1 } + + if (debug) { + def printArr(arr: Array[Double]): String = + arr.map(v => f"$v%.3f").mkString(", ") + println(f""" + |aProjected : ${printArr(aProjected)} + |bProjected : ${printArr(bProjected)} + |aNormalized: ${printArr(aNormalized)} + |bNormalized: ${printArr(bNormalized)} + |deltas : ${printArr(deltas)} + |result : $result%.4f + |""".stripMargin) + } result } + // java map is what thrift produces upon deserialization type Histogram = java.util.Map[String, java.lang.Long] - def categoricalDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + def histogramDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { + + @inline def sumValues(h: Histogram): Double = { + var result = 0.0 + val it = h.entrySet().iterator() + while (it.hasNext) { + result += it.next().getValue + } + result + } + val aSum = sumValues(a) + val bSum = sumValues(b) + val aIt = a.entrySet().iterator() var result = 0.0 val func = termFunc(metric) @@ -80,9 +187,9 @@ object DistanceMetrics { val entry = aIt.next() val key = entry.getKey val aVal = entry.getValue.toDouble - val bValOpt = b.get(key) - val bVal = if (bValOpt == null) bValOpt.toDouble else 0.0 - val term = func(aVal, bVal) + val bValOpt: java.lang.Long = b.get(key) + val bVal: Double = if (bValOpt == null) 0.0 else bValOpt.toDouble + val term = func(aVal / aSum, bVal / bSum) result += term } @@ -93,7 +200,8 @@ object DistanceMetrics { val bVal = entry.getValue.toDouble val aValOpt = a.get(key) if (aValOpt == null) { - result += func(0.0, bVal) + val term = func(0.0, bVal / bSum) + result += term } } @@ -118,7 +226,9 @@ object DistanceMetrics { @inline def psiTerm(a: Double, b: Double): Double = { - if (a > 0 && b > 0) (b - a) * log(b / a) else 0.0 + val aFixed = if (a == 0.0) 1e-5 else a + val bFixed = if (b == 0.0) 1e-5 else b + (bFixed - aFixed) * log(bFixed / aFixed) } @inline From 3ca60fa775e171137555b543bee1fb75e5fcee9b Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Thu, 7 Nov 2024 22:31:10 -0800 Subject: [PATCH 11/37] tile drift --- .../online/stats/DistanceMetrics.scala | 157 ++++-------------- .../test/stats/DistanceMetricsTest.scala | 128 ++++++++++++++ 2 files changed, 162 insertions(+), 123 deletions(-) create mode 100644 online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala index a586b95aab..1b12789de1 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala @@ -1,127 +1,9 @@ package ai.chronon.online.stats import ai.chronon.api.DriftMetric -import ai.chronon.api.TileSummaries import scala.math._ -import 
scala.util.ScalaJavaConversions.IteratorOps -import scala.util.ScalaJavaConversions.JMapOps object DistanceMetrics { - - // TODO move this to unit test - def main(args: Array[String]): Unit = { - - def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { - val stdDev = math.sqrt(variance) - - // Create probability points from 0.01 to 0.99 instead of 0 to 1 - val probPoints = (0 to breaks).map { i => - if (i == 0) 0.01 // p1 instead of p0 - else if (i == breaks) 0.99 // p99 instead of p100 - else i.toDouble / breaks - }.toArray - - // Convert probability points to percentiles - probPoints.map { p => - val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) - mean + (stdDev * standardNormalPercentile) - } - } - - // Helper function to calculate inverse error function - def inverseErf(x: Double): Double = { - // Approximation of inverse error function - // This is a rational approximation giving a maximum relative error of 3e-7 - val a = 0.147 - val signX = if (x >= 0) 1 else -1 - val absX = math.abs(x) - - val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) - val term2 = math.log(1 - absX * absX) / a - - signX * math.sqrt(term1 - term2) - } - - def compareDistributions(meanA: Double, - varianceA: Double, - meanB: Double, - varianceB: Double, - breaks: Int = 20, - debug: Boolean = false): Unit = { - - val aPercentiles = buildPercentiles(meanA, varianceA, breaks) - val bPercentiles = buildPercentiles(meanB, varianceB, breaks) - - val aHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val bHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val jsd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.JENSEN_SHANNON, debug = debug) - val jsdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.JENSEN_SHANNON) - println(f"The Jensen-Shannon Divergence between distributions A and B is: $jsd%.5f, $jsdHist%.5f") - - val psi = percentileDistance(aPercentiles, bPercentiles, DriftMetric.PSI, debug = debug) - val psiHist = histogramDistance(aHistogram, bHistogram, DriftMetric.PSI) - println(f"The Population Stability Index between distributions A and B is: $psi%.5f, $psiHist%.5f") - - val hd = percentileDistance(aPercentiles, bPercentiles, DriftMetric.HELLINGER, debug = debug) - val hdHist = histogramDistance(aHistogram, bHistogram, DriftMetric.HELLINGER) - println(f"The Hellinger Distance between distributions A and B is: $hd%.5f, $hdHist%.5f") - - println() - } - - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 205.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 305.0, varianceB = 256.0) - compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) - - } - - @inline - private def toArray(l: java.util.List[java.lang.Double]): Array[Double] = { - l.iterator().toScala.map(_.toDouble).toArray - } - - @inline - private def normalize(arr: Array[Double]): Array[Double] = { - // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot - val result = Array.ofDim[Double](arr.length) - val sum = arr.sum - var i = 0 - while (i < arr.length) { - result.update(i, arr(i) / sum) - i += 1 - } - result 
- } - - def distance(a: TileSummaries, b: TileSummaries, metric: DriftMetric): java.lang.Double = { - require(a.isSetPercentiles == b.isSetPercentiles, "Percentiles should be either set or unset together") - require(a.isSetHistogram == b.isSetHistogram, "Histograms should be either set or unset together") - - val isContinuous = a.isSetPercentiles && b.isSetPercentiles - val isCategorical = a.isSetHistogram && b.isSetHistogram - if (isContinuous) - percentileDistance(toArray(a.getPercentiles), toArray(b.getPercentiles), metric) - else if (isCategorical) - histogramDistance(a.getHistogram, b.getHistogram, metric) - else - null - } - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): Double = { val breaks = (a ++ b).sorted.distinct val aProjected = AssignIntervals.on(a, breaks) @@ -209,33 +91,62 @@ object DistanceMetrics { } @inline - def klDivergenceTerm(a: Double, b: Double): Double = { + private def normalize(arr: Array[Double]): Array[Double] = { + // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot + val result = Array.ofDim[Double](arr.length) + val sum = arr.sum + var i = 0 + while (i < arr.length) { + result.update(i, arr(i) / sum) + i += 1 + } + result + } + + @inline + private def klDivergenceTerm(a: Double, b: Double): Double = { if (a > 0 && b > 0) a * math.log(a / b) else 0 } @inline - def jsdTerm(a: Double, b: Double): Double = { + private def jsdTerm(a: Double, b: Double): Double = { val m = (a + b) * 0.5 (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 } @inline - def hellingerTerm(a: Double, b: Double): Double = { + private def hellingerTerm(a: Double, b: Double): Double = { pow(sqrt(a) - sqrt(b), 2) * 0.5 } @inline - def psiTerm(a: Double, b: Double): Double = { + private def psiTerm(a: Double, b: Double): Double = { val aFixed = if (a == 0.0) 1e-5 else a val bFixed = if (b == 0.0) 1e-5 else b (bFixed - aFixed) * log(bFixed / aFixed) } @inline - def termFunc(d: DriftMetric): (Double, Double) => Double = + private def termFunc(d: DriftMetric): (Double, Double) => Double = d match { case DriftMetric.PSI => psiTerm case DriftMetric.HELLINGER => hellingerTerm case DriftMetric.JENSEN_SHANNON => jsdTerm } + + case class Thresholds(moderate: Double, severe: Double) { + def str(driftScore: Double): String = { + if (driftScore < moderate) "LOW" + else if (driftScore < severe) "MODERATE" + else "SEVERE" + } + } + + @inline + def thresholds(d: DriftMetric): Thresholds = + d match { + case DriftMetric.JENSEN_SHANNON => Thresholds(0.05, 0.15) + case DriftMetric.HELLINGER => Thresholds(0.05, 0.15) + case DriftMetric.PSI => Thresholds(0.1, 0.2) + } } diff --git a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala new file mode 100644 index 0000000000..5e571ba5d7 --- /dev/null +++ b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala @@ -0,0 +1,128 @@ +package ai.chronon.online.test.stats + +import ai.chronon.api.DriftMetric +import ai.chronon.online.stats.DistanceMetrics.histogramDistance +import ai.chronon.online.stats.DistanceMetrics.percentileDistance +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import scala.util.ScalaJavaConversions.JMapOps + +class DistanceMetricsTest extends AnyFunSuite with Matchers { + + def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { + val stdDev = math.sqrt(variance) + 
+ val probPoints = (0 to breaks).map { i => + if (i == 0) 0.01 + else if (i == breaks) 0.99 + else i.toDouble / breaks + }.toArray + + probPoints.map { p => + val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) + mean + (stdDev * standardNormalPercentile) + } + } + + def inverseErf(x: Double): Double = { + val a = 0.147 + val signX = if (x >= 0) 1 else -1 + val absX = math.abs(x) + + val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) + val term2 = math.log(1 - absX * absX) / a + + signX * math.sqrt(term1 - term2) + } + type Histogram = java.util.Map[String, java.lang.Long] + + def compareDistributions(meanA: Double, + varianceA: Double, + meanB: Double, + varianceB: Double, + breaks: Int = 20, + debug: Boolean = false): Map[DriftMetric, (Double, Double)] = { + + val aPercentiles = buildPercentiles(meanA, varianceA, breaks) + val bPercentiles = buildPercentiles(meanB, varianceB, breaks) + + val aHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + val bHistogram: Histogram = (0 to breaks) + .map { i => + val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) + i.toString -> value + } + .toMap + .toJava + + def calculateDrift(metric: DriftMetric): (Double, Double) = { + val pDrift = percentileDistance(aPercentiles, bPercentiles, metric, debug = debug) + val histoDrift = histogramDistance(aHistogram, bHistogram, metric) + (pDrift, histoDrift) + } + + Map( + DriftMetric.JENSEN_SHANNON -> calculateDrift(DriftMetric.JENSEN_SHANNON), + DriftMetric.PSI -> calculateDrift(DriftMetric.PSI), + DriftMetric.HELLINGER -> calculateDrift(DriftMetric.HELLINGER) + ) + } + + test("Low drift - similar distributions") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 101.0, varianceB = 225.0) + + // JSD assertions + val (jsdPercentile, jsdHisto) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be < 0.05 + jsdHisto should be < 0.05 + + // Hellinger assertions + val (hellingerPercentile, hellingerHisto) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be < 0.05 + hellingerHisto should be < 0.05 + } + + test("Moderate drift - slightly different distributions") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should (be >= 0.05 and be <= 0.15) + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should (be >= 0.05 and be <= 0.15) + } + + test("Severe drift - different means") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 110.0, varianceB = 225.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be > 0.15 + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be > 0.15 + } + + test("Severe drift - different variances") { + val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) + + // JSD assertions + val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) + jsdPercentile should be > 0.15 + + // Hellinger assertions + val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) + hellingerPercentile should be > 0.15 + } +} From 
ea115e384d80c32af7fd76bdfe951bebb7530cc2 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Tue, 12 Nov 2024 05:27:44 -0800 Subject: [PATCH 12/37] test wiring --- hub/app/model/Model.scala | 7 + .../ai/chronon/online/stats/Display.scala | 205 ++++++++++++++++++ .../online/stats/DistanceMetrics.scala | 152 ------------- .../test/stats/DistanceMetricsTest.scala | 128 ----------- 4 files changed, 212 insertions(+), 280 deletions(-) create mode 100644 online/src/main/scala/ai/chronon/online/stats/Display.scala delete mode 100644 online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala delete mode 100644 online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index a498936611..a14ce4d679 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -5,6 +5,13 @@ case class GroupBy(name: String, features: Seq[String]) case class Join(name: String, joinFeatures: Seq[String], groupBys: Seq[GroupBy]) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) +// 1.) metadataUpload: join -> map> +// 2.) fetchJoinConf + listColumns: join => list +// 3.) (columns, start, end) -> list + +// 4.) 1:n/fetchTile: tileKey -> TileSummaries +// 5.) 1:n:n/compareTiles: TileSummaries, TileSummaries -> TileDrift +// 6.) Map[column, Seq[tileDrift]] -> TimeSeriesController /** Supported Metric types */ sealed trait MetricType diff --git a/online/src/main/scala/ai/chronon/online/stats/Display.scala b/online/src/main/scala/ai/chronon/online/stats/Display.scala new file mode 100644 index 0000000000..a4c757dc27 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/stats/Display.scala @@ -0,0 +1,205 @@ +package ai.chronon.online.stats + +import cask._ +import scalatags.Text.all._ +import scalatags.Text.tags2.title + +// generates html / js code to serve a tabbed board on the network port +// boards are static and do not update, used for debugging only +// uses uPlot under the hood +object Display { + // single line inside a chart + case class Series(series: Array[Double], name: String) + // multiple lines in a chart plus the x-axis and a threshold (horizontal dashed line) + case class Chart(seriesList: Array[Series], + x: Array[String], + name: String, + moderateThreshold: Option[Double] = None, + severeThreshold: Option[Double] = None) + + // multiple charts in a section + case class Section(charts: Array[Chart], name: String) + // multiple sections in a tab + case class Tab(sectionList: Array[Section], name: String) + // multiple tabs in a board + case class Board(tabList: Array[Tab], name: String) + + private def generateChartJs(chart: Chart, chartId: String): String = { + val data = chart.seriesList.map(_.series) + val xData = chart.x.map(_.toString) + chart.seriesList.map(_.name) + + val seriesConfig = chart.seriesList.map(s => s"""{ + | label: "${s.name}", + | stroke: "rgb(${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)})" + | + |}""".stripMargin).mkString(",\n") + + val thresholdLines = (chart.moderateThreshold.map(t => s""" + |{ + | label: "Moderate Threshold", + | value: $t, + | stroke: "#ff9800", + | style: [2, 2] + |}""".stripMargin) ++ + chart.severeThreshold.map(t => s""" + |{ + | label: "Severe Threshold", + | value: $t, + | stroke: "#f44336", + | style: [2, 2] + |}""".stripMargin)).mkString(",") + + s""" + |new uPlot({ + | title: "${chart.name}", + | id: "$chartId", + | class: 
"chart", + | width: 800, + | height: 400, + | scales: { + | x: { + | time: false, + | } + | }, + | series: [ + | {}, + | $seriesConfig + | ], + | axes: [ + | {}, + | { + | label: "Value", + | grid: true, + | } + | ], + | plugins: [ + | { + | hooks: { + | draw: u => { + | ${if (thresholdLines.nonEmpty) + s"""const lines = [$thresholdLines]; + | for (const line of lines) { + | const scale = u.scales.y; + | const y = scale.getPos(line.value); + | + | u.ctx.save(); + | u.ctx.strokeStyle = line.stroke; + | u.ctx.setLineDash(line.style); + | + | u.ctx.beginPath(); + | u.ctx.moveTo(u.bbox.left, y); + | u.ctx.lineTo(u.bbox.left + u.bbox.width, y); + | u.ctx.stroke(); + | + | u.ctx.restore(); + | }""".stripMargin + else ""} + | } + | } + | } + | ] + |}, [${xData.mkString("\"", "\",\"", "\"")}, ${data + .map(_.mkString(",")) + .mkString("[", "],[", "]")}], document.getElementById("$chartId")); + |""".stripMargin + } + + def serve(board: Board, portVal: Int = 9032): Unit = { + + object Server extends cask.MainRoutes { + @get("/") + def index() = { + val page = html( + head( + title(board.name), + script(src := "https://unpkg.com/uplot@1.6.24/dist/uPlot.iife.min.js"), + link(rel := "stylesheet", href := "https://unpkg.com/uplot@1.6.24/dist/uPlot.min.css"), + tag("style")(""" + |body { font-family: Arial, sans-serif; margin: 20px; } + |.tab { display: none; } + |.tab.active { display: block; } + |.tab-button { padding: 10px 20px; margin-right: 5px; cursor: pointer; } + |.tab-button.active { background-color: #ddd; } + |.section { margin: 20px 0; } + |.chart { margin: 20px 0; } + """.stripMargin) + ), + body( + h1(board.name), + div(cls := "tabs")( + board.tabList.map(tab => + button( + cls := "tab-button", + onclick := s"showTab('${tab.name}')", + tab.name + )) + ), + board.tabList.map(tab => + div(cls := "tab", id := tab.name)( + tab.sectionList.map(section => + div(cls := "section")( + h2(section.name), + section.charts.map(chart => + div(cls := "chart")( + div(id := s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-")) + )) + )) + )), + script(raw(""" + |function showTab(tabName) { + | document.querySelectorAll('.tab').forEach(tab => { + | tab.style.display = tab.id === tabName ? 
'block' : 'none'; + | }); + | document.querySelectorAll('.tab-button').forEach(button => { + | button.classList.toggle('active', button.textContent === tabName); + | }); + |} + | + |// Show first tab by default + |document.querySelector('.tab-button').click(); + """.stripMargin)), + script( + raw( + board.tabList + .flatMap(tab => + tab.sectionList.flatMap(section => + section.charts.map(chart => + generateChartJs(chart, s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-"))))) + .mkString("\n") + )) + ) + ) + +// page.render + + cask.Response( + page.render, + headers = Seq("Content-Type" -> "text/html") + ) + } + + override def host: String = "0.0.0.0" + override def port: Int = portVal + + initialize() + } + + Server.main(Array()) + } + + def main(args: Array[String]): Unit = { + val series = Array(Series(Array(1.0, 2.0, 3.0), "Series 1"), Series(Array(2.0, 3.0, 4.0), "Series 2")) + val chart = Chart(series, Array("A", "B", "C"), "Chart 1", Some(2.5), Some(3.5)) + val section = Section(Array(chart), "Section 1") + val tab = Tab(Array(section), "Tab 1") + val board = Board(Array(tab), "Board 1") + + println("serving board at http://localhost:9032/") + serve(board) + // Keep the program running + while (true) { + Thread.sleep(5000) + } + } +} diff --git a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala b/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala deleted file mode 100644 index 1b12789de1..0000000000 --- a/online/src/main/scala/ai/chronon/online/stats/DistanceMetrics.scala +++ /dev/null @@ -1,152 +0,0 @@ -package ai.chronon.online.stats -import ai.chronon.api.DriftMetric - -import scala.math._ - -object DistanceMetrics { - def percentileDistance(a: Array[Double], b: Array[Double], metric: DriftMetric, debug: Boolean = false): Double = { - val breaks = (a ++ b).sorted.distinct - val aProjected = AssignIntervals.on(a, breaks) - val bProjected = AssignIntervals.on(b, breaks) - - val aNormalized = normalize(aProjected) - val bNormalized = normalize(bProjected) - - val func = termFunc(metric) - - var i = 0 - var result = 0.0 - - // debug only, remove before merging - val deltas = Array.ofDim[Double](aNormalized.length) - - while (i < aNormalized.length) { - val ai = aNormalized(i) - val bi = bNormalized(i) - val delta = func(ai, bi) - - // debug only remove before merging - deltas.update(i, delta) - - result += delta - i += 1 - } - - if (debug) { - def printArr(arr: Array[Double]): String = - arr.map(v => f"$v%.3f").mkString(", ") - println(f""" - |aProjected : ${printArr(aProjected)} - |bProjected : ${printArr(bProjected)} - |aNormalized: ${printArr(aNormalized)} - |bNormalized: ${printArr(bNormalized)} - |deltas : ${printArr(deltas)} - |result : $result%.4f - |""".stripMargin) - } - result - } - - // java map is what thrift produces upon deserialization - type Histogram = java.util.Map[String, java.lang.Long] - def histogramDistance(a: Histogram, b: Histogram, metric: DriftMetric): Double = { - - @inline def sumValues(h: Histogram): Double = { - var result = 0.0 - val it = h.entrySet().iterator() - while (it.hasNext) { - result += it.next().getValue - } - result - } - val aSum = sumValues(a) - val bSum = sumValues(b) - - val aIt = a.entrySet().iterator() - var result = 0.0 - val func = termFunc(metric) - while (aIt.hasNext) { - val entry = aIt.next() - val key = entry.getKey - val aVal = entry.getValue.toDouble - val bValOpt: java.lang.Long = b.get(key) - val bVal: Double = if (bValOpt == null) 0.0 else bValOpt.toDouble - val term = 
func(aVal / aSum, bVal / bSum) - result += term - } - - val bIt = b.entrySet().iterator() - while (bIt.hasNext) { - val entry = bIt.next() - val key = entry.getKey - val bVal = entry.getValue.toDouble - val aValOpt = a.get(key) - if (aValOpt == null) { - val term = func(0.0, bVal / bSum) - result += term - } - } - - result - } - - @inline - private def normalize(arr: Array[Double]): Array[Double] = { - // TODO-OPTIMIZATION: normalize in place instead if this is a hotspot - val result = Array.ofDim[Double](arr.length) - val sum = arr.sum - var i = 0 - while (i < arr.length) { - result.update(i, arr(i) / sum) - i += 1 - } - result - } - - @inline - private def klDivergenceTerm(a: Double, b: Double): Double = { - if (a > 0 && b > 0) a * math.log(a / b) else 0 - } - - @inline - private def jsdTerm(a: Double, b: Double): Double = { - val m = (a + b) * 0.5 - (klDivergenceTerm(a, m) + klDivergenceTerm(b, m)) * 0.5 - } - - @inline - private def hellingerTerm(a: Double, b: Double): Double = { - pow(sqrt(a) - sqrt(b), 2) * 0.5 - } - - @inline - private def psiTerm(a: Double, b: Double): Double = { - val aFixed = if (a == 0.0) 1e-5 else a - val bFixed = if (b == 0.0) 1e-5 else b - (bFixed - aFixed) * log(bFixed / aFixed) - } - - @inline - private def termFunc(d: DriftMetric): (Double, Double) => Double = - d match { - case DriftMetric.PSI => psiTerm - case DriftMetric.HELLINGER => hellingerTerm - case DriftMetric.JENSEN_SHANNON => jsdTerm - } - - case class Thresholds(moderate: Double, severe: Double) { - def str(driftScore: Double): String = { - if (driftScore < moderate) "LOW" - else if (driftScore < severe) "MODERATE" - else "SEVERE" - } - } - - @inline - def thresholds(d: DriftMetric): Thresholds = - d match { - case DriftMetric.JENSEN_SHANNON => Thresholds(0.05, 0.15) - case DriftMetric.HELLINGER => Thresholds(0.05, 0.15) - case DriftMetric.PSI => Thresholds(0.1, 0.2) - } -} diff --git a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala b/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala deleted file mode 100644 index 5e571ba5d7..0000000000 --- a/online/src/test/scala/ai/chronon/online/test/stats/DistanceMetricsTest.scala +++ /dev/null @@ -1,128 +0,0 @@ -package ai.chronon.online.test.stats - -import ai.chronon.api.DriftMetric -import ai.chronon.online.stats.DistanceMetrics.histogramDistance -import ai.chronon.online.stats.DistanceMetrics.percentileDistance -import org.scalatest.funsuite.AnyFunSuite -import org.scalatest.matchers.should.Matchers - -import scala.util.ScalaJavaConversions.JMapOps - -class DistanceMetricsTest extends AnyFunSuite with Matchers { - - def buildPercentiles(mean: Double, variance: Double, breaks: Int = 20): Array[Double] = { - val stdDev = math.sqrt(variance) - - val probPoints = (0 to breaks).map { i => - if (i == 0) 0.01 - else if (i == breaks) 0.99 - else i.toDouble / breaks - }.toArray - - probPoints.map { p => - val standardNormalPercentile = math.sqrt(2) * inverseErf(2 * p - 1) - mean + (stdDev * standardNormalPercentile) - } - } - - def inverseErf(x: Double): Double = { - val a = 0.147 - val signX = if (x >= 0) 1 else -1 - val absX = math.abs(x) - - val term1 = math.pow(2 / (math.Pi * a) + math.log(1 - absX * absX) / 2, 0.5) - val term2 = math.log(1 - absX * absX) / a - - signX * math.sqrt(term1 - term2) - } - type Histogram = java.util.Map[String, java.lang.Long] - - def compareDistributions(meanA: Double, - varianceA: Double, - meanB: Double, - varianceB: Double, - breaks: Int = 20, - debug: Boolean = 
false): Map[DriftMetric, (Double, Double)] = { - - val aPercentiles = buildPercentiles(meanA, varianceA, breaks) - val bPercentiles = buildPercentiles(meanB, varianceB, breaks) - - val aHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(aPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - val bHistogram: Histogram = (0 to breaks) - .map { i => - val value = java.lang.Long.valueOf((math.abs(bPercentiles(i)) * 100).toLong) - i.toString -> value - } - .toMap - .toJava - - def calculateDrift(metric: DriftMetric): (Double, Double) = { - val pDrift = percentileDistance(aPercentiles, bPercentiles, metric, debug = debug) - val histoDrift = histogramDistance(aHistogram, bHistogram, metric) - (pDrift, histoDrift) - } - - Map( - DriftMetric.JENSEN_SHANNON -> calculateDrift(DriftMetric.JENSEN_SHANNON), - DriftMetric.PSI -> calculateDrift(DriftMetric.PSI), - DriftMetric.HELLINGER -> calculateDrift(DriftMetric.HELLINGER) - ) - } - - test("Low drift - similar distributions") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 101.0, varianceB = 225.0) - - // JSD assertions - val (jsdPercentile, jsdHisto) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be < 0.05 - jsdHisto should be < 0.05 - - // Hellinger assertions - val (hellingerPercentile, hellingerHisto) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be < 0.05 - hellingerHisto should be < 0.05 - } - - test("Moderate drift - slightly different distributions") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 256.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should (be >= 0.05 and be <= 0.15) - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should (be >= 0.05 and be <= 0.15) - } - - test("Severe drift - different means") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 110.0, varianceB = 225.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be > 0.15 - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be > 0.15 - } - - test("Severe drift - different variances") { - val drifts = compareDistributions(meanA = 100.0, varianceA = 225.0, meanB = 105.0, varianceB = 100.0) - - // JSD assertions - val (jsdPercentile, _) = drifts(DriftMetric.JENSEN_SHANNON) - jsdPercentile should be > 0.15 - - // Hellinger assertions - val (hellingerPercentile, _) = drifts(DriftMetric.HELLINGER) - hellingerPercentile should be > 0.15 - } -} From 1af7b767e50f3a2ad95f5134ca6cd9a5ea7b19a7 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 20 Nov 2024 15:48:43 -0500 Subject: [PATCH 13/37] Rename DynamoDB store to monitoring model store --- hub/app/controllers/ModelController.scala | 4 ++-- hub/app/controllers/SearchController.scala | 4 ++-- hub/app/module/DynamoDBModule.scala | 16 ---------------- hub/app/module/ModelStoreModule.scala | 16 ++++++++++++++++ ...ingStore.scala => MonitoringModelStore.scala} | 10 +++++----- hub/conf/application.conf | 2 +- hub/test/controllers/ModelControllerSpec.scala | 4 ++-- hub/test/controllers/SearchControllerSpec.scala | 4 ++-- ...Test.scala => MonitoringModelStoreTest.scala} | 6 +++--- 9 files changed, 33 insertions(+), 33 deletions(-) delete mode 100644 hub/app/module/DynamoDBModule.scala 
create mode 100644 hub/app/module/ModelStoreModule.scala rename hub/app/store/{DynamoDBMonitoringStore.scala => MonitoringModelStore.scala} (94%) rename hub/test/store/{DynamoDBMonitoringStoreTest.scala => MonitoringModelStoreTest.scala} (92%) diff --git a/hub/app/controllers/ModelController.scala b/hub/app/controllers/ModelController.scala index e895c8c27f..40ef41a56c 100644 --- a/hub/app/controllers/ModelController.scala +++ b/hub/app/controllers/ModelController.scala @@ -4,7 +4,7 @@ import io.circe.generic.auto._ import io.circe.syntax._ import model.ListModelResponse import play.api.mvc._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore import javax.inject._ @@ -13,7 +13,7 @@ import javax.inject._ */ @Singleton class ModelController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: DynamoDBMonitoringStore) + monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/SearchController.scala b/hub/app/controllers/SearchController.scala index ac6b39110e..cb36e76a62 100644 --- a/hub/app/controllers/SearchController.scala +++ b/hub/app/controllers/SearchController.scala @@ -5,7 +5,7 @@ import io.circe.syntax._ import model.Model import model.SearchModelResponse import play.api.mvc._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore import javax.inject._ @@ -13,7 +13,7 @@ import javax.inject._ * Controller to power search related APIs */ class SearchController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: DynamoDBMonitoringStore) + monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/module/DynamoDBModule.scala b/hub/app/module/DynamoDBModule.scala deleted file mode 100644 index c9d1ab307e..0000000000 --- a/hub/app/module/DynamoDBModule.scala +++ /dev/null @@ -1,16 +0,0 @@ -package module - -import ai.chronon.integrations.aws.AwsApiImpl -import com.google.inject.AbstractModule -import play.api.Configuration -import play.api.Environment -import store.DynamoDBMonitoringStore - -class DynamoDBModule(environment: Environment, configuration: Configuration) extends AbstractModule { - - override def configure(): Unit = { - val awsApiImpl = new AwsApiImpl(Map.empty) - val dynamoDBMonitoringStore = new DynamoDBMonitoringStore(awsApiImpl) - bind(classOf[DynamoDBMonitoringStore]).toInstance(dynamoDBMonitoringStore) - } -} diff --git a/hub/app/module/ModelStoreModule.scala b/hub/app/module/ModelStoreModule.scala new file mode 100644 index 0000000000..801faeaa77 --- /dev/null +++ b/hub/app/module/ModelStoreModule.scala @@ -0,0 +1,16 @@ +package module + +import ai.chronon.integrations.aws.AwsApiImpl +import com.google.inject.AbstractModule +import play.api.Configuration +import play.api.Environment +import store.MonitoringModelStore + +class ModelStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { + + override def configure(): Unit = { + val awsApiImpl = new AwsApiImpl(Map.empty) + val dynamoDBMonitoringStore = new MonitoringModelStore(awsApiImpl) + bind(classOf[MonitoringModelStore]).toInstance(dynamoDBMonitoringStore) + } +} diff --git a/hub/app/store/DynamoDBMonitoringStore.scala b/hub/app/store/MonitoringModelStore.scala similarity index 94% rename from hub/app/store/DynamoDBMonitoringStore.scala rename to hub/app/store/MonitoringModelStore.scala index d435f9ad52..bfa1f5e895 100644 --- a/hub/app/store/DynamoDBMonitoringStore.scala +++ 
b/hub/app/store/MonitoringModelStore.scala @@ -31,10 +31,10 @@ case class LoadedConfs(joins: Seq[api.Join] = Seq.empty, stagingQueries: Seq[api.StagingQuery] = Seq.empty, models: Seq[api.Model] = Seq.empty) -class DynamoDBMonitoringStore(apiImpl: Api) { +class MonitoringModelStore(apiImpl: Api) { - val dynamoDBKVStore: KVStore = apiImpl.genKvStore - implicit val executionContext: ExecutionContext = dynamoDBKVStore.executionContext + val kvStore: KVStore = apiImpl.genKvStore + implicit val executionContext: ExecutionContext = kvStore.executionContext // to help periodically refresh the load config catalog, we wrap this in a TTL cache lazy val configRegistryCache: TTLCache[String, LoadedConfs] = { @@ -59,7 +59,7 @@ class DynamoDBMonitoringStore(apiImpl: Api) { GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns) } - val outputColumns = thriftJoin.outputColumnsByGroup.values.flatten.toArray + val outputColumns = thriftJoin.ooutputColumnsByGroup.getOrElse("derivations", Array.empty) val join = Join(thriftJoin.metaData.name, outputColumns, groupBys) Option( Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name())) @@ -81,7 +81,7 @@ class DynamoDBMonitoringStore(apiImpl: Api) { } val listRequest = ListRequest(MetadataEndPoint.ConfByKeyEndPointName, propsMap) logger.info(s"Triggering list conf lookup with request: $listRequest") - dynamoDBKVStore.list(listRequest).flatMap { response => + kvStore.list(listRequest).flatMap { response => val newLoadedConfs = makeLoadedConfs(response) val newAcc = LoadedConfs( acc.joins ++ newLoadedConfs.joins, diff --git a/hub/conf/application.conf b/hub/conf/application.conf index c3dceef48d..1d6b9996bf 100644 --- a/hub/conf/application.conf +++ b/hub/conf/application.conf @@ -28,4 +28,4 @@ play.filters.cors { } # Add DynamoDB module -play.modules.enabled += "module.DynamoDBModule" +play.modules.enabled += "module.ModelStoreModule" diff --git a/hub/test/controllers/ModelControllerSpec.scala b/hub/test/controllers/ModelControllerSpec.scala index 8e33830093..95b96ca24a 100644 --- a/hub/test/controllers/ModelControllerSpec.scala +++ b/hub/test/controllers/ModelControllerSpec.scala @@ -15,7 +15,7 @@ import play.api.http.Status.OK import play.api.mvc._ import play.api.test.Helpers._ import play.api.test._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore class ModelControllerSpec extends PlaySpec with Results with EitherValues { @@ -24,7 +24,7 @@ class ModelControllerSpec extends PlaySpec with Results with EitherValues { // Create a stub ControllerComponents val stubCC: ControllerComponents = stubControllerComponents() // Create a mocked DynDB store - val mockedStore: DynamoDBMonitoringStore = mock(classOf[DynamoDBMonitoringStore]) + val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore]) val controller = new ModelController(stubCC, mockedStore) diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala index 95807cfc35..5510ea4010 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ b/hub/test/controllers/SearchControllerSpec.scala @@ -14,14 +14,14 @@ import play.api.http.Status.OK import play.api.mvc._ import play.api.test.Helpers._ import play.api.test._ -import store.DynamoDBMonitoringStore +import store.MonitoringModelStore class SearchControllerSpec extends PlaySpec with Results with EitherValues { // Create a stub ControllerComponents val stubCC: ControllerComponents = stubControllerComponents() // 
Create a mocked DynDB store - val mockedStore: DynamoDBMonitoringStore = mock(classOf[DynamoDBMonitoringStore]) + val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore]) val controller = new SearchController(stubCC, mockedStore) diff --git a/hub/test/store/DynamoDBMonitoringStoreTest.scala b/hub/test/store/MonitoringModelStoreTest.scala similarity index 92% rename from hub/test/store/DynamoDBMonitoringStoreTest.scala rename to hub/test/store/MonitoringModelStoreTest.scala index 6a77915a3c..a2fcb925a8 100644 --- a/hub/test/store/DynamoDBMonitoringStoreTest.scala +++ b/hub/test/store/MonitoringModelStoreTest.scala @@ -21,7 +21,7 @@ import scala.io.Source import scala.util.Success import scala.util.Try -class DynamoDBMonitoringStoreTest extends MockitoSugar with Matchers { +class MonitoringModelStoreTest extends MockitoSugar with Matchers { var api: Api = _ var kvStore: KVStore = _ @@ -41,13 +41,13 @@ class DynamoDBMonitoringStoreTest extends MockitoSugar with Matchers { @Test def monitoringStoreShouldReturnModels(): Unit = { - val dynamoDBMonitoringStore = new DynamoDBMonitoringStore(api) + val dynamoDBMonitoringStore = new MonitoringModelStore(api) when(kvStore.list(any())).thenReturn(generateListResponse()) validateLoadedConfigs(dynamoDBMonitoringStore) } - private def validateLoadedConfigs(dynamoDBMonitoringStore: DynamoDBMonitoringStore): Unit = { + private def validateLoadedConfigs(dynamoDBMonitoringStore: MonitoringModelStore): Unit = { // check that our store has loaded the relevant artifacts dynamoDBMonitoringStore.getConfigRegistry.models.length shouldBe 1 dynamoDBMonitoringStore.getConfigRegistry.groupBys.length shouldBe 2 From 1228488b88aa642abcd21dd5f10e5f5102f27453 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:17:18 -0500 Subject: [PATCH 14/37] First cut wiring up with passing tests --- .../controllers/TimeSeriesController.scala | 253 ++++++++++-------- hub/app/module/DriftStoreModule.scala | 15 ++ hub/app/store/MonitoringModelStore.scala | 2 +- hub/conf/application.conf | 1 + hub/conf/routes | 2 +- .../TimeSeriesControllerSpec.scala | 187 ++++++++++--- .../ai/chronon/online/stats/Display.scala | 205 -------------- .../ai/chronon/online/stats/DriftStore.scala | 6 + .../online/stats/TileDriftCalculator.scala | 2 +- 9 files changed, 313 insertions(+), 360 deletions(-) create mode 100644 hub/app/module/DriftStoreModule.scala delete mode 100644 online/src/main/scala/ai/chronon/online/stats/Display.scala diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index f3f546d6b0..0f253999a3 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -1,19 +1,23 @@ package controllers -import ai.chronon.api.DriftMetric +import ai.chronon.api.Extensions.WindowOps +import ai.chronon.api.{DriftMetric, TileDriftSeries, TileSummarySeries, TimeUnit, Window} +import ai.chronon.online.stats.DriftStore import io.circe.generic.auto._ import io.circe.syntax._ import model._ import play.api.mvc._ import javax.inject._ +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ -import scala.util.Random +import scala.util.{Failure, Random, Success} +import scala.jdk.CollectionConverters._ /** * Controller that serves various time series endpoints at the model, join and feature level */ @Singleton -class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents) extends BaseController { +class 
TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit ec: ExecutionContext) extends BaseController { import TimeSeriesController._ @@ -25,17 +29,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon def fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String): Action[AnyContent] = doFetchModel(id, startTs, endTs, offset, algorithm) - /** - * Helps retrieve a model time series with the data sliced based on the relevant slice (identified by sliceId) - */ - def fetchModelSlice(id: String, - sliceId: String, - startTs: Long, - endTs: Long, - offset: String, - algorithm: String): Action[AnyContent] = - doFetchModel(id, startTs, endTs, offset, algorithm, Some(sliceId)) - /** * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. Time series is * retrieved between the start and end ts. If the metric type is for drift, the offset is used to compute the @@ -50,28 +43,15 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon metrics: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - doFetchJoin(name, startTs, endTs, metricType, metrics, None, offset, algorithm) - - /** - * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. The data is sliced - * based on the configured slice (looked up by sliceId) - */ - def fetchJoinSlice(name: String, - sliceId: String, - startTs: Long, - endTs: Long, - metricType: String, - metrics: String, - offset: Option[String], - algorithm: Option[String]): Action[AnyContent] = - doFetchJoin(name, startTs, endTs, metricType, metrics, Some(sliceId), offset, algorithm) + doFetchJoin(name, startTs, endTs, metricType, metrics, offset, algorithm) /** * Helps retrieve a time series (drift or skew) for a given feature. Time series is * retrieved between the start and end ts. Choice of granularity (raw, aggregate, percentiles) along with the * metric type (drift / skew) dictates the shape of the returned time series. */ - def fetchFeature(name: String, + def fetchFeature(join: String, + name: String, startTs: Long, endTs: Long, metricType: String, @@ -79,29 +59,13 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon granularity: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - doFetchFeature(name, startTs, endTs, metricType, metrics, None, granularity, offset, algorithm) - - /** - * Helps retrieve a time series (drift or skew) for a given feature. 
The data is sliced based on the configured slice - * (looked up by sliceId) - */ - def fetchFeatureSlice(name: String, - sliceId: String, - startTs: Long, - endTs: Long, - metricType: String, - metrics: String, - granularity: String, - offset: Option[String], - algorithm: Option[String]): Action[AnyContent] = - doFetchFeature(name, startTs, endTs, metricType, metrics, Some(sliceId), granularity, offset, algorithm) + doFetchFeature(join, name, startTs, endTs, metricType, metrics, granularity, offset, algorithm) private def doFetchModel(id: String, startTs: Long, endTs: Long, offset: String, - algorithm: String, - sliceId: Option[String] = None): Action[AnyContent] = + algorithm: String): Action[AnyContent] = Action { implicit request: Request[AnyContent] => (parseOffset(Some(offset)), parseAlgorithm(Some(algorithm))) match { case (None, _) => BadRequest(s"Unable to parse offset - $offset") @@ -118,18 +82,17 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon endTs: Long, metricType: String, metrics: String, - slice: Option[String], offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => + Action.async { implicit request: Request[AnyContent] => val metricChoice = parseMetricChoice(Some(metricType)) val metricRollup = parseMetricRollup(Some(metrics)) (metricChoice, metricRollup) match { - case (None, _) => BadRequest("Invalid metric choice. Expect drift / skew") - case (_, None) => BadRequest("Invalid metric rollup. Expect null / value") - case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, slice, offset, algorithm) - case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup, slice) + case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (_, None) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) + case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, offset, algorithm) + case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup) } } @@ -137,34 +100,36 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], offset: Option[String], - algorithm: Option[String]): Result = { + algorithm: Option[String]): Future[Result] = { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val mockGroupBys = generateMockGroupBys(3) - val groupByTimeSeries = mockGroupBys.map { g => - val mockFeatures = generateMockFeatures(g, 10) - val featureTS = mockFeatures.map { - FeatureTimeSeries(_, generateMockTimeSeriesPoints(startTs, endTs)) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) + case (Some(o), Some(driftMetric)) => + val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + val maybeDriftSeries = driftStore.getDriftSeries(name, driftMetric, window, startTs, endTs) + maybeDriftSeries match { + case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => driftSeriesFuture.map { + driftSeries => + // pull up a list of drift series objects for all the features in a group + val grpToDriftSeriesList: Map[String, Seq[TileDriftSeries]] = driftSeries.groupBy(_.key.groupName) + val groupByTimeSeries = grpToDriftSeriesList.map { + case (name, featureDriftSeriesInfoSeq) => GroupByTimeSeries(name, featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) + }.toSeq + + val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) + Ok(tsData.asJson.noSpaces) } - GroupByTimeSeries(g, featureTS) } - - val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) - Ok(mockTSData.asJson.noSpaces) } } private def doFetchJoinSkew(name: String, startTs: Long, endTs: Long, - metric: Metric, - sliceId: Option[String]): Result = { + metric: Metric): Future[Result] = { val mockGroupBys = generateMockGroupBys(3) val groupByTimeSeries = mockGroupBys.map { g => val mockFeatures = generateMockFeatures(g, 10) @@ -176,71 +141,80 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) val json = mockTSData.asJson.noSpaces - Ok(json) + Future.successful(Ok(json)) } - private def doFetchFeature(name: String, + private def doFetchFeature(join: String, + name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, - slice: Option[String], granularity: String, offset: Option[String], algorithm: Option[String]): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => + Action.async { implicit request: Request[AnyContent] => val metricChoice = parseMetricChoice(Some(metricType)) val metricRollup = parseMetricRollup(Some(metrics)) val granularityType = parseGranularity(granularity) (metricChoice, metricRollup, granularityType) match { - case (None, _, _) => BadRequest("Invalid metric choice. Expect drift / skew") - case (_, None, _) => BadRequest("Invalid metric rollup. Expect null / value") - case (_, _, None) => BadRequest("Invalid granularity. Expect raw / percentile / aggregates") + case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (_, None, _) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) + case (_, _, None) => Future.successful(BadRequest("Invalid granularity. 
Expect raw / percentile / aggregates")) case (Some(Drift), Some(rollup), Some(g)) => - doFetchFeatureDrift(name, startTs, endTs, rollup, slice, g, offset, algorithm) - case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, slice, g) + doFetchFeatureDrift(join, name, startTs, endTs, rollup, g, offset, algorithm) + case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, g) } } - private def doFetchFeatureDrift(name: String, + private def doFetchFeatureDrift(join: String, + name: String, startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], granularity: Granularity, offset: Option[String], - algorithm: Option[String]): Result = { + algorithm: Option[String]): Future[Result] = { if (granularity == Raw) { - BadRequest("We don't support Raw granularity for drift metric types") + Future.successful(BadRequest("We don't support Raw granularity for drift metric types")) } else { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. Expect PSI or KL") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val featureTsJson = if (granularity == Aggregates) { - // if feature name ends in an even digit we consider it continuous and generate mock data accordingly - // else we generate mock data for a categorical feature - val featureId = name.split("_").last.toInt - val featureTs = if (featureId % 2 == 0) { - ComparedFeatureTimeSeries(name, - generateMockRawTimeSeriesPoints(startTs, 100), - generateMockRawTimeSeriesPoints(startTs, 100)) - } else { - ComparedFeatureTimeSeries(name, - generateMockCategoricalTimeSeriesPoints(startTs, 5, 1), - generateMockCategoricalTimeSeriesPoints(startTs, 5, 2)) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) + case (Some(o), Some(driftMetric)) => + val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + if (granularity == Aggregates) { + val maybeDriftSeries = + driftStore.getDriftSeries(join, driftMetric, window, startTs, endTs, Some(name)) + maybeDriftSeries match { + case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => driftSeriesFuture.map { + driftSeries => + val featureTs = convertTileDriftSeriesInfoToTimeSeries(driftSeries.head, metric) + Ok(featureTs.asJson.noSpaces) + } } - featureTs.asJson } else { - // - //{new: Array[Double], old: Array[Double], x: Array[String]} - //{old_null_count: Long, new_null_count: long, old_total_count: Long, new_total_count: Long} - - FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)).asJson + // percentiles + val maybeCurrentSummarySeries = driftStore.getSummarySeries(join, startTs, endTs, Some(name)) + val maybeBaselineSummarySeries = driftStore.getSummarySeries(join, startTs - window.millis, endTs - window.millis, Some(name)) + (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { + case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) + case (_, Failure(exception)) => Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) + case (Failure(exception), _) => Future.successful(InternalServerError(s"Error computing feature percentiles for current time window - ${exception.getMessage}")) + case (Success(currentSummarySeriesFuture), Success(baselineSummarySeriesFuture)) => + Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { + merged => + val currentSummarySeries = merged.head + val baselineSummarySeries = merged.last + val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) + Ok(comparedTsData.asJson.noSpaces) + } + } } - Ok(featureTsJson.noSpaces) } } } @@ -249,10 +223,9 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon startTs: Long, endTs: Long, metric: Metric, - sliceId: Option[String], - granularity: Granularity): Result = { + granularity: Granularity): Future[Result] = { if (granularity == Aggregates) { - BadRequest("We don't support Aggregates granularity for skew metric types") + Future.successful(BadRequest("We don't support Aggregates granularity for skew metric types")) } else { val featureTsJson = if (granularity == Raw) { val featureTs = ComparedFeatureTimeSeries(name, @@ -263,7 +236,53 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val featuresTs = FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)) featuresTs.asJson.noSpaces } - Ok(featureTsJson) + Future.successful(Ok(featureTsJson)) + } + } + + private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { + val lhsList = if (metric == NullMetric) { + tileDriftSeries.nullRatioChangePercentSeries.asScala + } else { + // 
check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles + // then we have a numeric feature at hand + val isNumeric = tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala.exists(_ != null) + if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala + else tileDriftSeries.histogramDriftSeries.asScala + } + val points = lhsList.zip(tileDriftSeries.timestamps.asScala).map { + case (v, ts) => TimeSeriesPoint(v, ts) + } + + FeatureTimeSeries(tileDriftSeries.getKey.getColumn, points) + } + + private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, metric: Metric): Seq[TimeSeriesPoint] = { + if (metric == NullMetric) { + summarySeries.nullCount.asScala.zip(summarySeries.timestamps.asScala).map { + case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue())) + } + } else { + // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles + // then we have a numeric feature at hand + val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) + if (isNumeric) { + summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap { + case (percentiles, ts) => + DriftStore.percentileLabels.zip(percentiles.asScala).map { + case (l, value) => TimeSeriesPoint(value, ts, Some(l)) + } + } + } + else { + summarySeries.timestamps.asScala.zipWithIndex.flatMap { + case (ts, idx) => + summarySeries.histogram.asScala.map { + case (label, values) => + TimeSeriesPoint(values.get(idx).toDouble, ts, Some(label)) + } + } + } } } } @@ -281,13 +300,11 @@ object TimeSeriesController { } def parseAlgorithm(algorithm: Option[String]): Option[DriftMetric] = { - algorithm.map { - _.toLowerCase match { - case "psi" => DriftMetric.PSI - case "hellinger" => DriftMetric.HELLINGER - case "jsd" => DriftMetric.JENSEN_SHANNON - case _ => throw new IllegalArgumentException("Invalid drift algorithm. 
Pick one of PSI, Hellinger or JSD") - } + algorithm.map(_.toLowerCase) match { + case Some("psi") => Some(DriftMetric.PSI) + case Some("hellinger") => Some(DriftMetric.HELLINGER) + case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) + case _ => None } } diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala new file mode 100644 index 0000000000..ee831e5cad --- /dev/null +++ b/hub/app/module/DriftStoreModule.scala @@ -0,0 +1,15 @@ +package module + +import ai.chronon.integrations.aws.AwsApiImpl +import ai.chronon.online.stats.DriftStore +import com.google.inject.AbstractModule +import play.api.{Configuration, Environment} + +class DriftStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { + + override def configure(): Unit = { + val awsApiImpl = new AwsApiImpl(Map.empty) + val driftStore = new DriftStore(awsApiImpl.genKvStore) + bind(classOf[DriftStore]).toInstance(driftStore) + } +} diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index bfa1f5e895..d4d67d8f4b 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -59,7 +59,7 @@ class MonitoringModelStore(apiImpl: Api) { GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns) } - val outputColumns = thriftJoin.ooutputColumnsByGroup.getOrElse("derivations", Array.empty) + val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) val join = Join(thriftJoin.metaData.name, outputColumns, groupBys) Option( Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name())) diff --git a/hub/conf/application.conf b/hub/conf/application.conf index 1d6b9996bf..5696edf2d2 100644 --- a/hub/conf/application.conf +++ b/hub/conf/application.conf @@ -29,3 +29,4 @@ play.filters.cors { # Add DynamoDB module play.modules.enabled += "module.ModelStoreModule" +play.modules.enabled += "module.DriftStoreModule" diff --git a/hub/conf/routes b/hub/conf/routes index 7f686530ba..7c88e44bb6 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -12,7 +12,7 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # join -> seq(feature) # when metricType == "drift" - will return time series list of drift values -GET /api/v1/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) # TODO - move the core flow to fine-grained end-points diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 2d431e0a24..8fb07cb511 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -1,9 +1,13 @@ package controllers +import ai.chronon.api.{TileDriftSeries, TileSeriesKey, TileSummarySeries} +import ai.chronon.online.stats.DriftStore import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ import model._ +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito.{mock, when} import org.scalatest.EitherValues import org.scalatestplus.play._ import 
play.api.http.Status.BAD_REQUEST @@ -13,14 +17,24 @@ import play.api.test.Helpers._ import play.api.test._ import java.util.concurrent.TimeUnit +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration +import scala.util.{Failure, Success, Try} +import java.lang.{Double => JDouble} +import java.lang.{Long => JLong} +import scala.jdk.CollectionConverters._ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { // Create a stub ControllerComponents val stubCC: ControllerComponents = stubControllerComponents() - val controller = new TimeSeriesController(stubCC) + implicit val ec: ExecutionContext = ExecutionContext.global + + // Create a mocked drift store + val mockedStore: DriftStore = mock(classOf[DriftStore]) + val controller = new TimeSeriesController(stubCC, mockedStore) + val mockCategories: Seq[String] = Seq("a", "b", "c") "TimeSeriesController's model ts lookup" should { @@ -78,11 +92,26 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { status(invalid1) mustBe BAD_REQUEST } - "send valid results on a correctly formed model ts drift lookup request" in { + "send 5xx on failed drift store lookup request" in { + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(Failure(new IllegalArgumentException("Some internal error"))) + val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC val result = controller.fetchJoin("my_join", startTs, endTs, "drift", "null", Some("10h"), Some("psi")).apply(FakeRequest()) + + status(result) mustBe INTERNAL_SERVER_ERROR + } + + "send valid results on a correctly formed model ts drift lookup request" in { + val startTs = 1725926400000L // 09/10/2024 00:00 UTC + val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedDriftStoreResponse = generateDriftSeries(startTs, endTs, "my_join", 2, 3) + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(mockedDriftStoreResponse) + + val result = + controller.fetchJoin("my_join", startTs, endTs, "drift", "value", Some("10h"), Some("psi")).apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) val modelTSResponse: Either[Error, JoinTimeSeriesResponse] = decode[JoinTimeSeriesResponse](bodyText) @@ -123,32 +152,32 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid metric choice" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid metric rollup" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "drift", "woof", "raw", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "woof", "raw", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid granularity" in { val invalid = - controller.fetchFeature("my_feature", 123L, 456L, "drift", "null", "woof", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "woof", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST } "send 400 on an invalid time offset for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, 
"drift", "null", "aggregates", Some("Xh"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("Xh"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST val invalid2 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "aggregates", Some("-1h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("-1h"), Some("psi")) .apply(FakeRequest()) status(invalid2) mustBe BAD_REQUEST } @@ -156,7 +185,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid algorithm for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "aggregates", Some("10h"), Some("meow")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "aggregates", Some("10h"), Some("meow")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } @@ -164,7 +193,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid granularity for drift metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "drift", "null", "raw", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "drift", "null", "raw", Some("10h"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } @@ -172,74 +201,101 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { "send 400 on an invalid granularity for skew metric" in { val invalid1 = controller - .fetchFeature("my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) .apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST } - "send valid results on a correctly formed numeric feature ts aggregate drift lookup request" in { + "send valid results on a correctly formed feature ts aggregate drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedDriftStoreResponse = generateDriftSeries(startTs, endTs, "my_join", 1, 1) + when(mockedStore.getDriftSeries(any(), any(), any(), any(), any(), any())).thenReturn(mockedDriftStoreResponse) + val result = controller - .fetchFeature("my_feature_0", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature_0", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) + val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value response.feature mustBe "my_feature_0" - response.current.length mustBe response.baseline.length - response.current.zip(response.baseline).foreach { - case (current, baseline) => - current.ts mustBe baseline.ts - } + val expectedLength = expectedHours(startTs, endTs) + response.points.length mustBe expectedLength } - "send valid results on a correctly formed categorical feature ts aggregate drift lookup request" in { + "send valid results on a correctly formed numeric feature ts percentile drift lookup request" in { val startTs = 1725926400000L // 
09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedSummarySeriesResponseA = generateSummarySeries(startTs, endTs, "my_join", "my_groupby", "my_feature", ValuesMetric, true) + val offset = Duration.apply(7, TimeUnit.DAYS) + val mockedSummarySeriesResponseB = + generateSummarySeries(startTs - offset.toMillis, endTs - offset.toMillis, "my_join", "my_groupby", "my_feature", ValuesMetric, true) + when(mockedStore.getSummarySeries(any(), any(), any(), any())).thenReturn(mockedSummarySeriesResponseA, mockedSummarySeriesResponseB) + val result = controller - .fetchFeature("my_feature_1", startTs, endTs, "drift", "null", "aggregates", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", startTs, endTs, "drift", "value", "percentile", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value - response.feature mustBe "my_feature_1" - response.current.map(_.ts).toSet mustBe response.baseline.map(_.ts).toSet - response.current.foreach(_.label.isEmpty mustBe false) - response.baseline.foreach(_.label.isEmpty mustBe false) + response.feature mustBe "my_feature" + response.current.length mustBe response.baseline.length + response.current.zip(response.baseline).foreach { + case (current, baseline) => + (current.ts - baseline.ts) mustBe offset.toMillis + } + + // expect one entry per percentile for each time series point + val expectedLength = DriftStore.percentileLabels.length * expectedHours(startTs, endTs) + response.current.length mustBe expectedLength + response.baseline.length mustBe expectedLength } - "send valid results on a correctly formed feature ts percentile drift lookup request" in { + "send valid results on a correctly formed categorical feature ts percentile drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC + + val mockedSummarySeriesResponseA = generateSummarySeries(startTs, endTs, "my_join", "my_groupby", "my_feature", ValuesMetric, false) + val offset = Duration.apply(7, TimeUnit.DAYS) + val mockedSummarySeriesResponseB = + generateSummarySeries(startTs - offset.toMillis, endTs - offset.toMillis, "my_join", "my_groupby", "my_feature", ValuesMetric, false) + when(mockedStore.getSummarySeries(any(), any(), any(), any())).thenReturn(mockedSummarySeriesResponseA, mockedSummarySeriesResponseB) + val result = controller - .fetchFeature("my_feature", startTs, endTs, "drift", "null", "percentile", Some("10h"), Some("psi")) + .fetchFeature("my_join", "my_feature", startTs, endTs, "drift", "value", "percentile", Some("10h"), Some("psi")) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) + val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = decode[ComparedFeatureTimeSeries](bodyText) featureTSResponse.isRight mustBe true val response = featureTSResponse.right.value response.feature mustBe "my_feature" - response.points.nonEmpty mustBe true - - // expect one entry per percentile for each time series point - val expectedLength = TimeSeriesController.mockGeneratedPercentiles.length * expectedHours(startTs, endTs) - response.points.length mustBe expectedLength + response.current.length 
mustBe response.baseline.length + // expect one entry per category for each time series point + val expectedLength = mockCategories.length * expectedHours(startTs, endTs) + response.current.length mustBe expectedLength + response.current.zip(response.baseline).foreach { + case (current, baseline) => + (current.ts - baseline.ts) mustBe offset.toMillis + current.label.isEmpty mustBe false + baseline.label.isEmpty mustBe false + } } "send valid results on a correctly formed feature ts raw skew lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC val result = - controller.fetchFeature("my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest()) + controller.fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = @@ -259,7 +315,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { val endTs = 1726106400000L // 09/12/2024 02:00 UTC val result = controller - .fetchFeature("my_feature", startTs, endTs, "skew", "null", "percentile", None, None) + .fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "percentile", None, None) .apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) @@ -278,4 +334,67 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { private def expectedHours(startTs: Long, endTs: Long): Long = { Duration(endTs - startTs, TimeUnit.MILLISECONDS).toHours } + + private def generateDriftSeries(startTs: Long, endTs: Long, join: String, numGroups: Int, numFeaturesPerGroup: Int): Try[Future[Seq[TileDriftSeries]]] = { + val result = for { + group <- 0 until numGroups + feature <- 0 until numFeaturesPerGroup + } yield { + val name = s"my_group_$group" + val featureName = s"my_feature_$feature" + val tileKey = new TileSeriesKey() + tileKey.setNodeName(join) + tileKey.setGroupName(name) + tileKey.setColumn(featureName) + + val tileDriftSeries = new TileDriftSeries() + tileDriftSeries.setKey(tileKey) + + val timestamps = (startTs until endTs by (Duration(1, TimeUnit.HOURS).toMillis)).toList.map(JLong.valueOf(_)).asJava + // if feature name ends in an even digit we consider it continuous and generate mock data accordingly + // else we generate mock data for a categorical feature + val isNumeric = if (feature % 2 == 0) true else false + val percentileDrifts = if (isNumeric) List.fill(timestamps.size())(JDouble.valueOf(0.12)).asJava else List.fill[JDouble](timestamps.size())(null).asJava + val histogramDrifts = if (isNumeric) List.fill[JDouble](timestamps.size())(null).asJava else List.fill(timestamps.size())(JDouble.valueOf(0.23)).asJava + val nullRationChangePercents = List.fill(timestamps.size())(JDouble.valueOf(0.25)).asJava + tileDriftSeries.setTimestamps(timestamps) + tileDriftSeries.setPercentileDriftSeries(percentileDrifts) + tileDriftSeries.setNullRatioChangePercentSeries(nullRationChangePercents) + tileDriftSeries.setHistogramDriftSeries(histogramDrifts) + } + Success(Future.successful(result)) + } + + private def generateSummarySeries(startTs: Long, endTs: Long, join: String, groupBy: String, featureName: String, metric: Metric, isNumeric: Boolean): Try[Future[Seq[TileSummarySeries]]] = { + val tileKey = new TileSeriesKey() + tileKey.setNodeName(join) + tileKey.setGroupName(groupBy) + 
tileKey.setNodeName(join) + tileKey.setColumn(featureName) + + val timestamps = (startTs until endTs by (Duration(1, TimeUnit.HOURS).toMillis)).toList.map(JLong.valueOf(_)) + val tileSummarySeries = new TileSummarySeries() + tileSummarySeries.setKey(tileKey) + tileSummarySeries.setTimestamps(timestamps.asJava) + + if (metric == NullMetric) { + tileSummarySeries.setNullCount(List.fill(timestamps.length)(JLong.valueOf(1)).asJava) + } else { + if (isNumeric) { + val percentileList = timestamps.map { + _ => + List.fill(DriftStore.percentileLabels.length)(JDouble.valueOf(0.12)).asJava + }.asJava + tileSummarySeries.setPercentiles(percentileList) + } else { + val histogramMap = mockCategories.map { + category => + category -> List.fill(timestamps.length)(JLong.valueOf(1)).asJava + }.toMap.asJava + tileSummarySeries.setHistogram(histogramMap) + } + } + + Success(Future.successful(Seq(tileSummarySeries))) + } } diff --git a/online/src/main/scala/ai/chronon/online/stats/Display.scala b/online/src/main/scala/ai/chronon/online/stats/Display.scala deleted file mode 100644 index a4c757dc27..0000000000 --- a/online/src/main/scala/ai/chronon/online/stats/Display.scala +++ /dev/null @@ -1,205 +0,0 @@ -package ai.chronon.online.stats - -import cask._ -import scalatags.Text.all._ -import scalatags.Text.tags2.title - -// generates html / js code to serve a tabbed board on the network port -// boards are static and do not update, used for debugging only -// uses uPlot under the hood -object Display { - // single line inside a chart - case class Series(series: Array[Double], name: String) - // multiple lines in a chart plus the x-axis and a threshold (horizontal dashed line) - case class Chart(seriesList: Array[Series], - x: Array[String], - name: String, - moderateThreshold: Option[Double] = None, - severeThreshold: Option[Double] = None) - - // multiple charts in a section - case class Section(charts: Array[Chart], name: String) - // multiple sections in a tab - case class Tab(sectionList: Array[Section], name: String) - // multiple tabs in a board - case class Board(tabList: Array[Tab], name: String) - - private def generateChartJs(chart: Chart, chartId: String): String = { - val data = chart.seriesList.map(_.series) - val xData = chart.x.map(_.toString) - chart.seriesList.map(_.name) - - val seriesConfig = chart.seriesList.map(s => s"""{ - | label: "${s.name}", - | stroke: "rgb(${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)}, ${scala.util.Random.nextInt(255)})" - | - |}""".stripMargin).mkString(",\n") - - val thresholdLines = (chart.moderateThreshold.map(t => s""" - |{ - | label: "Moderate Threshold", - | value: $t, - | stroke: "#ff9800", - | style: [2, 2] - |}""".stripMargin) ++ - chart.severeThreshold.map(t => s""" - |{ - | label: "Severe Threshold", - | value: $t, - | stroke: "#f44336", - | style: [2, 2] - |}""".stripMargin)).mkString(",") - - s""" - |new uPlot({ - | title: "${chart.name}", - | id: "$chartId", - | class: "chart", - | width: 800, - | height: 400, - | scales: { - | x: { - | time: false, - | } - | }, - | series: [ - | {}, - | $seriesConfig - | ], - | axes: [ - | {}, - | { - | label: "Value", - | grid: true, - | } - | ], - | plugins: [ - | { - | hooks: { - | draw: u => { - | ${if (thresholdLines.nonEmpty) - s"""const lines = [$thresholdLines]; - | for (const line of lines) { - | const scale = u.scales.y; - | const y = scale.getPos(line.value); - | - | u.ctx.save(); - | u.ctx.strokeStyle = line.stroke; - | u.ctx.setLineDash(line.style); - | - | u.ctx.beginPath(); - | 
u.ctx.moveTo(u.bbox.left, y); - | u.ctx.lineTo(u.bbox.left + u.bbox.width, y); - | u.ctx.stroke(); - | - | u.ctx.restore(); - | }""".stripMargin - else ""} - | } - | } - | } - | ] - |}, [${xData.mkString("\"", "\",\"", "\"")}, ${data - .map(_.mkString(",")) - .mkString("[", "],[", "]")}], document.getElementById("$chartId")); - |""".stripMargin - } - - def serve(board: Board, portVal: Int = 9032): Unit = { - - object Server extends cask.MainRoutes { - @get("/") - def index() = { - val page = html( - head( - title(board.name), - script(src := "https://unpkg.com/uplot@1.6.24/dist/uPlot.iife.min.js"), - link(rel := "stylesheet", href := "https://unpkg.com/uplot@1.6.24/dist/uPlot.min.css"), - tag("style")(""" - |body { font-family: Arial, sans-serif; margin: 20px; } - |.tab { display: none; } - |.tab.active { display: block; } - |.tab-button { padding: 10px 20px; margin-right: 5px; cursor: pointer; } - |.tab-button.active { background-color: #ddd; } - |.section { margin: 20px 0; } - |.chart { margin: 20px 0; } - """.stripMargin) - ), - body( - h1(board.name), - div(cls := "tabs")( - board.tabList.map(tab => - button( - cls := "tab-button", - onclick := s"showTab('${tab.name}')", - tab.name - )) - ), - board.tabList.map(tab => - div(cls := "tab", id := tab.name)( - tab.sectionList.map(section => - div(cls := "section")( - h2(section.name), - section.charts.map(chart => - div(cls := "chart")( - div(id := s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-")) - )) - )) - )), - script(raw(""" - |function showTab(tabName) { - | document.querySelectorAll('.tab').forEach(tab => { - | tab.style.display = tab.id === tabName ? 'block' : 'none'; - | }); - | document.querySelectorAll('.tab-button').forEach(button => { - | button.classList.toggle('active', button.textContent === tabName); - | }); - |} - | - |// Show first tab by default - |document.querySelector('.tab-button').click(); - """.stripMargin)), - script( - raw( - board.tabList - .flatMap(tab => - tab.sectionList.flatMap(section => - section.charts.map(chart => - generateChartJs(chart, s"${tab.name}-${section.name}-${chart.name}".replaceAll("\\s+", "-"))))) - .mkString("\n") - )) - ) - ) - -// page.render - - cask.Response( - page.render, - headers = Seq("Content-Type" -> "text/html") - ) - } - - override def host: String = "0.0.0.0" - override def port: Int = portVal - - initialize() - } - - Server.main(Array()) - } - - def main(args: Array[String]): Unit = { - val series = Array(Series(Array(1.0, 2.0, 3.0), "Series 1"), Series(Array(2.0, 3.0, 4.0), "Series 2")) - val chart = Chart(series, Array("A", "B", "C"), "Chart 1", Some(2.5), Some(3.5)) - val section = Section(Array(chart), "Section 1") - val tab = Tab(Array(section), "Tab 1") - val board = Board(Array(tab), "Board 1") - - println("serving board at http://localhost:9032/") - serve(board) - // Keep the program running - while (true) { - Thread.sleep(5000) - } - } -} diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index eea086e883..44c83e37c1 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -203,4 +203,10 @@ object DriftStore { def compactSerializer: SerializableSerializer = new SerializableSerializer(new TBinaryProtocol.Factory()) def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + + // todo - drop this hard-coded list in favor of a well known 
list or exposing as part of summaries + val percentileLabels: Seq[String] = Seq("p0", "p5", "p10", "p15", "p20", + "p25", "p30", "p35", "p40", "p45", + "p50", "p55", "p60", "p65", "p70", + "p75", "p80", "p85", "p90", "p95", "p100") } diff --git a/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala b/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala index 47518f555a..ef22b4cf1a 100644 --- a/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala +++ b/online/src/main/scala/ai/chronon/online/stats/TileDriftCalculator.scala @@ -81,7 +81,7 @@ object TileDriftCalculator { result } - // for each summary with ts >= startMs, use spec.lookBack to find the previous summary and calculate dirft + // for each summary with ts >= startMs, use spec.lookBack to find the previous summary and calculate drift // we do this by first creating a map of summaries by timestamp def toTileDrifts(summariesWithTimestamps: Array[(TileSummary, Long)], metric: DriftMetric, From 86b1cfd85f767aba1aa22924de25252dfeb1e2c2 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:31:35 -0500 Subject: [PATCH 15/37] Rip out mock data generation and corresponding endpoints --- .../controllers/TimeSeriesController.scala | 116 +----------------- hub/conf/routes | 4 +- .../TimeSeriesControllerSpec.scala | 108 ++-------------- 3 files changed, 16 insertions(+), 212 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 0f253999a3..625fcd729e 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -21,14 +21,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon import TimeSeriesController._ - /** - * Helps retrieve a model performance drift time series. Time series is retrieved between the start and end ts. - * The offset is used to compute the distribution to compare against (we compare current time range with the same - * sized time range starting offset time period prior). - */ - def fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String): Action[AnyContent] = - doFetchModel(id, startTs, endTs, offset, algorithm) - /** * Helps retrieve a time series (drift or skew) for each of the features that are part of a join. Time series is * retrieved between the start and end ts. If the metric type is for drift, the offset is used to compute the @@ -61,22 +53,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon algorithm: Option[String]): Action[AnyContent] = doFetchFeature(join, name, startTs, endTs, metricType, metrics, granularity, offset, algorithm) - private def doFetchModel(id: String, - startTs: Long, - endTs: Long, - offset: String, - algorithm: String): Action[AnyContent] = - Action { implicit request: Request[AnyContent] => - (parseOffset(Some(offset)), parseAlgorithm(Some(algorithm))) match { - case (None, _) => BadRequest(s"Unable to parse offset - $offset") - case (_, None) => BadRequest("Invalid drift algorithm. 
Expect PSI or KL") - case (Some(_), Some(_)) => - // TODO: Use parsedOffset and parsedAlgorithm when ready - val mockTSData = ModelTimeSeriesResponse(id, generateMockTimeSeriesPoints(startTs, endTs)) - Ok(mockTSData.asJson.noSpaces) - } - } - private def doFetchJoin(name: String, startTs: Long, endTs: Long, @@ -89,10 +65,9 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val metricRollup = parseMetricRollup(Some(metrics)) (metricChoice, metricRollup) match { - case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (None, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift")) case (_, None) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) case (Some(Drift), Some(rollup)) => doFetchJoinDrift(name, startTs, endTs, rollup, offset, algorithm) - case (Some(Skew), Some(rollup)) => doFetchJoinSkew(name, startTs, endTs, rollup) } } @@ -126,24 +101,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def doFetchJoinSkew(name: String, - startTs: Long, - endTs: Long, - metric: Metric): Future[Result] = { - val mockGroupBys = generateMockGroupBys(3) - val groupByTimeSeries = mockGroupBys.map { g => - val mockFeatures = generateMockFeatures(g, 10) - val featureTS = mockFeatures.map { - FeatureTimeSeries(_, generateMockTimeSeriesPoints(startTs, endTs)) - } - GroupByTimeSeries(g, featureTS) - } - - val mockTSData = JoinTimeSeriesResponse(name, groupByTimeSeries) - val json = mockTSData.asJson.noSpaces - Future.successful(Ok(json)) - } - private def doFetchFeature(join: String, name: String, startTs: Long, @@ -159,12 +116,11 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon val granularityType = parseGranularity(granularity) (metricChoice, metricRollup, granularityType) match { - case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift / skew")) + case (None, _, _) => Future.successful(BadRequest("Invalid metric choice. Expect drift")) case (_, None, _) => Future.successful(BadRequest("Invalid metric rollup. Expect null / value")) case (_, _, None) => Future.successful(BadRequest("Invalid granularity. 
Expect raw / percentile / aggregates")) case (Some(Drift), Some(rollup), Some(g)) => doFetchFeatureDrift(join, name, startTs, endTs, rollup, g, offset, algorithm) - case (Some(Skew), Some(rollup), Some(g)) => doFetchFeatureSkew(name, startTs, endTs, rollup, g) } } @@ -219,27 +175,6 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def doFetchFeatureSkew(name: String, - startTs: Long, - endTs: Long, - metric: Metric, - granularity: Granularity): Future[Result] = { - if (granularity == Aggregates) { - Future.successful(BadRequest("We don't support Aggregates granularity for skew metric types")) - } else { - val featureTsJson = if (granularity == Raw) { - val featureTs = ComparedFeatureTimeSeries(name, - generateMockRawTimeSeriesPoints(startTs, 100), - generateMockRawTimeSeriesPoints(startTs, 100)) - featureTs.asJson.noSpaces - } else { - val featuresTs = FeatureTimeSeries(name, generateMockTimeSeriesPercentilePoints(startTs, endTs)) - featuresTs.asJson.noSpaces - } - Future.successful(Ok(featureTsJson)) - } - } - private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { val lhsList = if (metric == NullMetric) { tileDriftSeries.nullRatioChangePercentSeries.asScala @@ -308,11 +243,12 @@ object TimeSeriesController { } } + // We currently only support drift def parseMetricChoice(metricType: Option[String]): Option[MetricType] = { metricType.map(_.toLowerCase) match { case Some("drift") => Some(Drift) - case Some("skew") => Some(Skew) - case Some("ooc") => Some(Skew) +// case Some("skew") => Some(Skew) +// case Some("ooc") => Some(Skew) case _ => None } } @@ -333,46 +269,4 @@ object TimeSeriesController { case _ => None } } - - // !!!!! Mock generation code !!!!! 
// - - val mockGeneratedPercentiles: Seq[String] = - Seq("p0", "p10", "p20", "p30", "p40", "p50", "p60", "p70", "p75", "p80", "p90", "p95", "p99", "p100") - - // temporarily serve up mock data while we wait on hooking up our KV store layer + drift calculation - private def generateMockTimeSeriesPoints(startTs: Long, endTs: Long): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (startTs until endTs by (1.hours.toMillis)).map(ts => TimeSeriesPoint(random.nextDouble(), ts)) - } - - private def generateMockRawTimeSeriesPoints(timestamp: Long, count: Int): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (0 until count).map(_ => TimeSeriesPoint(random.nextDouble(), timestamp)) - } - - private def generateMockCategoricalTimeSeriesPoints(timestamp: Long, - categoryCount: Int, - nullCategoryCount: Int): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - val catTSPoints = (0 until categoryCount).map(i => TimeSeriesPoint(random.nextInt(1000), timestamp, Some(s"A_$i"))) - val nullCatTSPoints = (0 until nullCategoryCount).map(i => - TimeSeriesPoint(random.nextDouble(), timestamp, Some(s"A_{$i + $categoryCount}"), Some(random.nextInt(10)))) - catTSPoints ++ nullCatTSPoints - } - - private def generateMockTimeSeriesPercentilePoints(startTs: Long, endTs: Long): Seq[TimeSeriesPoint] = { - val random = new Random(1000) - (startTs until endTs by (1.hours.toMillis)).flatMap { ts => - mockGeneratedPercentiles.zipWithIndex.map { - case (p, _) => TimeSeriesPoint(random.nextDouble(), ts, Some(p)) - } - } - } - - private def generateMockGroupBys(numGroupBys: Int): Seq[String] = - (1 to numGroupBys).map(i => s"my_groupby_$i") - - private def generateMockFeatures(groupBy: String, featuresPerGroupBy: Int): Seq[String] = - (1 to featuresPerGroupBy).map(i => s"$groupBy.my_feature_$i") - } diff --git a/hub/conf/routes b/hub/conf/routes index 7c88e44bb6..896fb7655c 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -3,8 +3,8 @@ GET /api/v1/ping controllers.Application GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) -# model prediction & model drift -GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) +# model prediction & model drift - this is TBD at the moment +# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) # all timeseries of a given join id # when metricType == "drift" - will return time series list of drift values diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 8fb07cb511..77d1f16806 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -36,41 +36,16 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { val controller = new TimeSeriesController(stubCC, mockedStore) val mockCategories: Seq[String] = Seq("a", "b", "c") - "TimeSeriesController's model ts lookup" should { + "TimeSeriesController's join ts lookup" should { - "send 400 on an invalid time offset" in { - val invalid1 = controller.fetchModel("id-123", 123L, 456L, "Xh", "psi").apply(FakeRequest()) + "send 400 on an invalid metric choice" in { + val invalid1 = controller.fetchJoin("my_join", 123L, 
456L, "meow", "null", None, None).apply(FakeRequest()) status(invalid1) mustBe BAD_REQUEST - val invalid2 = controller.fetchModel("id-123", 123L, 456L, "-10h", "psi").apply(FakeRequest()) + val invalid2 = controller.fetchJoin("my_join", 123L, 456L, "skew", "null", None, None).apply(FakeRequest()) status(invalid2) mustBe BAD_REQUEST } - "send 400 on an invalid algorithm" in { - val invalid1 = controller.fetchModel("id-123", 123L, 456L, "10h", "meow").apply(FakeRequest()) - status(invalid1) mustBe BAD_REQUEST - } - - "send valid results on a correctly formed model ts request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = controller.fetchModel("id-123", startTs, endTs, "10h", "psi").apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val modelTSResponse: Either[Error, ModelTimeSeriesResponse] = decode[ModelTimeSeriesResponse](bodyText) - modelTSResponse.isRight mustBe true - val items = modelTSResponse.right.value.items - items.length mustBe (Duration(endTs, TimeUnit.MILLISECONDS) - Duration(startTs, TimeUnit.MILLISECONDS)).toHours - } - } - - "TimeSeriesController's join ts lookup" should { - - "send 400 on an invalid metric choice" in { - val invalid = controller.fetchJoin("my_join", 123L, 456L, "meow", "null", None, None).apply(FakeRequest()) - status(invalid) mustBe BAD_REQUEST - } - "send 400 on an invalid metric rollup" in { val invalid = controller.fetchJoin("my_join", 123L, 456L, "drift", "woof", None, None).apply(FakeRequest()) status(invalid) mustBe BAD_REQUEST @@ -126,34 +101,16 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { grpByTs.items.foreach(featureTs => featureTs.points.length mustBe expectedLength) } } - - "send valid results on a correctly formed model ts skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller.fetchJoin("my_join", startTs, endTs, "skew", "null", None, None).apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val modelTSResponse: Either[Error, JoinTimeSeriesResponse] = decode[JoinTimeSeriesResponse](bodyText) - modelTSResponse.isRight mustBe true - val response = modelTSResponse.right.value - response.name mustBe "my_join" - response.items.nonEmpty mustBe true - - val expectedLength = expectedHours(startTs, endTs) - response.items.foreach { grpByTs => - grpByTs.items.isEmpty mustBe false - grpByTs.items.foreach(featureTs => featureTs.points.length mustBe expectedLength) - } - } } "TimeSeriesController's feature ts lookup" should { "send 400 on an invalid metric choice" in { - val invalid = - controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) - status(invalid) mustBe BAD_REQUEST + val invalid1 = controller.fetchFeature("my_join", "my_feature", 123L, 456L, "meow", "null", "raw", None, None).apply(FakeRequest()) + status(invalid1) mustBe BAD_REQUEST + + val invalid2 = controller.fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "raw", None, None).apply(FakeRequest()) + status(invalid2) mustBe BAD_REQUEST } "send 400 on an invalid metric rollup" in { @@ -198,14 +155,6 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { status(invalid1) mustBe BAD_REQUEST } - "send 400 on an invalid granularity for skew metric" in { - val invalid1 = - controller - 
.fetchFeature("my_join", "my_feature", 123L, 456L, "skew", "null", "aggregates", Some("10h"), Some("psi")) - .apply(FakeRequest()) - status(invalid1) mustBe BAD_REQUEST - } - "send valid results on a correctly formed feature ts aggregate drift lookup request" in { val startTs = 1725926400000L // 09/10/2024 00:00 UTC val endTs = 1726106400000L // 09/12/2024 02:00 UTC @@ -290,45 +239,6 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { baseline.label.isEmpty mustBe false } } - - "send valid results on a correctly formed feature ts raw skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller.fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "raw", None, None).apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, ComparedFeatureTimeSeries] = - decode[ComparedFeatureTimeSeries](bodyText) - featureTSResponse.isRight mustBe true - val response = featureTSResponse.right.value - response.feature mustBe "my_feature" - response.baseline.nonEmpty mustBe true - response.baseline.length mustBe response.current.length - // we expect a skew distribution at a fixed time stamp - response.baseline.foreach(p => p.ts mustBe startTs) - response.current.foreach(p => p.ts mustBe startTs) - } - - "send valid results on a correctly formed feature ts percentile skew lookup request" in { - val startTs = 1725926400000L // 09/10/2024 00:00 UTC - val endTs = 1726106400000L // 09/12/2024 02:00 UTC - val result = - controller - .fetchFeature("my_join", "my_feature", startTs, endTs, "skew", "null", "percentile", None, None) - .apply(FakeRequest()) - status(result) mustBe OK - val bodyText = contentAsString(result) - val featureTSResponse: Either[Error, FeatureTimeSeries] = decode[FeatureTimeSeries](bodyText) - featureTSResponse.isRight mustBe true - val response = featureTSResponse.right.value - response.feature mustBe "my_feature" - response.points.nonEmpty mustBe true - - // expect one entry per percentile for each time series point - val expectedLength = TimeSeriesController.mockGeneratedPercentiles.length * expectedHours(startTs, endTs) - response.points.length mustBe expectedLength - } } private def expectedHours(startTs: Long, endTs: Long): Long = { From 3657bb7f51668a86c4240a175d0e51a9ab4f330c Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Thu, 21 Nov 2024 20:54:06 -0500 Subject: [PATCH 16/37] Add joins endpoints and switch search to use joins --- hub/app/controllers/JoinController.scala | 42 ++++++++++++ hub/app/controllers/Paginate.scala | 4 +- hub/app/controllers/SearchController.scala | 17 +++-- hub/app/model/Model.scala | 5 +- hub/app/store/MonitoringModelStore.scala | 11 ++++ hub/conf/routes | 5 +- hub/test/controllers/JoinControllerSpec.scala | 65 +++++++++++++++++++ .../controllers/SearchControllerSpec.scala | 25 ++++--- 8 files changed, 150 insertions(+), 24 deletions(-) create mode 100644 hub/app/controllers/JoinController.scala create mode 100644 hub/test/controllers/JoinControllerSpec.scala diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala new file mode 100644 index 0000000000..81fd73a970 --- /dev/null +++ b/hub/app/controllers/JoinController.scala @@ -0,0 +1,42 @@ +package controllers + +import io.circe.generic.auto._ +import io.circe.syntax._ +import model.ListJoinResponse +import play.api.mvc._ +import 
store.MonitoringModelStore + +import javax.inject._ + +/** + * Controller for the Zipline Join entities + */ +@Singleton +class JoinController @Inject()(val controllerComponents: ControllerComponents, + monitoringStore: MonitoringModelStore) + extends BaseController + with Paginate { + + /** + * Powers the /api/v1/joins endpoint. Returns a list of joins + * @param offset - For pagination. We skip over offset entries before returning results + * @param limit - Number of elements to return + */ + def list(offset: Option[Int], limit: Option[Int]): Action[AnyContent] = + Action { implicit request: Request[AnyContent] => + // Default values if the parameters are not provided + val offsetValue = offset.getOrElse(defaultOffset) + val limitValue = limit.map(l => math.min(l, maxLimit)).getOrElse(defaultLimit) + + if (offsetValue < 0) { + BadRequest("Invalid offset - expect a positive number") + } else if (limitValue < 0) { + BadRequest("Invalid limit - expect a positive number") + } else { + val joins = monitoringStore.getJoins + val paginatedResults = paginateResults(joins, offsetValue, limitValue) + val json = ListJoinResponse(offsetValue, paginatedResults).asJson.noSpaces + Ok(json) + } + } +} diff --git a/hub/app/controllers/Paginate.scala b/hub/app/controllers/Paginate.scala index d77060cded..86a4eec3e5 100644 --- a/hub/app/controllers/Paginate.scala +++ b/hub/app/controllers/Paginate.scala @@ -1,13 +1,11 @@ package controllers -import model.Model - trait Paginate { val defaultOffset = 0 val defaultLimit = 10 val maxLimit = 100 - def paginateResults(results: Seq[Model], offset: Int, limit: Int): Seq[Model] = { + def paginateResults[T](results: Seq[T], offset: Int, limit: Int): Seq[T] = { results.slice(offset, offset + limit) } } diff --git a/hub/app/controllers/SearchController.scala b/hub/app/controllers/SearchController.scala index cb36e76a62..a6bd1a8ead 100644 --- a/hub/app/controllers/SearchController.scala +++ b/hub/app/controllers/SearchController.scala @@ -2,8 +2,7 @@ package controllers import io.circe.generic.auto._ import io.circe.syntax._ -import model.Model -import model.SearchModelResponse +import model.{Join, SearchJoinResponse} import play.api.mvc._ import store.MonitoringModelStore @@ -18,8 +17,8 @@ class SearchController @Inject() (val controllerComponents: ControllerComponents with Paginate { /** - * Powers the /api/v1/search endpoint. Returns a list of models - * @param term - Search term to search for (currently we only support searching model names) + * Powers the /api/v1/search endpoint. Returns a list of joins + * @param term - Search term to search for (currently we only support searching join names) * @param offset - For pagination. 
We skip over offset entries before returning results * @param limit - Number of elements to return */ @@ -36,14 +35,14 @@ class SearchController @Inject() (val controllerComponents: ControllerComponents } else { val searchResults = searchRegistry(term) val paginatedResults = paginateResults(searchResults, offsetValue, limitValue) - val json = SearchModelResponse(offsetValue, paginatedResults).asJson.noSpaces + val json = SearchJoinResponse(offsetValue, paginatedResults).asJson.noSpaces Ok(json) } } - // a trivial search where we check the model name for similarity with the search term - private def searchRegistry(term: String): Seq[Model] = { - val models = monitoringStore.getModels - models.filter(m => m.name.contains(term)) + // a trivial search where we check the join name for similarity with the search term + private def searchRegistry(term: String): Seq[Join] = { + val joins = monitoringStore.getJoins + joins.filter(j => j.name.contains(term)) } } diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index a14ce4d679..f3e2d9f54c 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -56,9 +56,10 @@ case class FeatureTimeSeries(feature: String, points: Seq[TimeSeriesPoint]) case class ComparedFeatureTimeSeries(feature: String, baseline: Seq[TimeSeriesPoint], current: Seq[TimeSeriesPoint]) case class GroupByTimeSeries(name: String, items: Seq[FeatureTimeSeries]) -// Currently search only covers models -case class SearchModelResponse(offset: Int, items: Seq[Model]) +// Currently search only covers joins case class ListModelResponse(offset: Int, items: Seq[Model]) +case class SearchJoinResponse(offset: Int, items: Seq[Join]) +case class ListJoinResponse(offset: Int, items: Seq[Join]) case class ModelTimeSeriesResponse(id: String, items: Seq[TimeSeriesPoint]) case class JoinTimeSeriesResponse(name: String, items: Seq[GroupByTimeSeries]) diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index d4d67d8f4b..9dd11d280c 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -69,6 +69,17 @@ class MonitoringModelStore(apiImpl: Api) { } } + def getJoins: Seq[Join] = { + configRegistryCache("default").joins.map { thriftJoin => + val groupBys = thriftJoin.joinParts.asScala.map { part => + GroupBy(part.groupBy.metaData.name, part.groupBy.valueColumns) + } + + val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) + Join(thriftJoin.metaData.name, outputColumns, groupBys) + } + } + val logger: Logger = Logger(this.getClass) val defaultListLookupLimit: Int = 100 diff --git a/hub/conf/routes b/hub/conf/routes index 896fb7655c..5d24c3bd52 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -1,10 +1,11 @@ # Backend APIs GET /api/v1/ping controllers.ApplicationController.ping() GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) +GET /api/v1/joins controllers.JoinController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) # model prediction & model drift - this is TBD at the moment -# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) +# GET /api/v1/model/:id/timeseries controllers.TimeSeriesController.fetchModel(id: String, startTs: Long, endTs: Long, offset: String, algorithm: String) # all 
timeseries of a given join id # when metricType == "drift" - will return time series list of drift values @@ -12,7 +13,7 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # join -> seq(feature) # when metricType == "drift" - will return time series list of drift values -GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) # TODO - move the core flow to fine-grained end-points diff --git a/hub/test/controllers/JoinControllerSpec.scala b/hub/test/controllers/JoinControllerSpec.scala new file mode 100644 index 0000000000..b6924cfdd1 --- /dev/null +++ b/hub/test/controllers/JoinControllerSpec.scala @@ -0,0 +1,65 @@ +package controllers + +import controllers.MockJoinService.mockJoinRegistry +import io.circe._ +import io.circe.generic.auto._ +import io.circe.parser._ +import model.ListJoinResponse +import org.mockito.Mockito._ +import org.scalatest.EitherValues +import org.scalatestplus.play._ +import play.api.http.Status.BAD_REQUEST +import play.api.http.Status.OK +import play.api.mvc._ +import play.api.test.Helpers._ +import play.api.test._ +import store.MonitoringModelStore + +class JoinControllerSpec extends PlaySpec with Results with EitherValues { + + // Create a stub ControllerComponents + val stubCC: ControllerComponents = stubControllerComponents() + // Create a mocked DynDB store + val mockedStore: MonitoringModelStore = mock(classOf[MonitoringModelStore]) + + val controller = new JoinController(stubCC, mockedStore) + + "JoinController" should { + + "send 400 on a bad offset" in { + val result = controller.list(Some(-1), Some(10)).apply(FakeRequest()) + status(result) mustBe BAD_REQUEST + } + + "send 400 on a bad limit" in { + val result = controller.list(Some(10), Some(-2)).apply(FakeRequest()) + status(result) mustBe BAD_REQUEST + } + + "send valid results on a correctly formed request" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.list(None, None).apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items + items.length mustBe controller.defaultLimit + items.map(_.name.toInt).toSet mustBe (0 until 10).toSet + } + + "send results in a paginated fashion correctly" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val startOffset = 25 + val number = 20 + val result = controller.list(Some(startOffset), Some(number)).apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items + items.length mustBe number + items.map(_.name.toInt).toSet mustBe (startOffset until startOffset + number).toSet + } + } +} diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala index 5510ea4010..b188336fa9 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ 
b/hub/test/controllers/SearchControllerSpec.scala @@ -1,10 +1,10 @@ package controllers -import controllers.MockDataService.mockModelRegistry +import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ -import model.ListModelResponse +import model.{GroupBy, Join, ListJoinResponse} import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.scalatest.EitherValues @@ -38,19 +38,19 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { } "send valid results on a correctly formed request" in { - when(mockedStore.getModels).thenReturn(mockModelRegistry) + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) val result = controller.search("1", None, None).apply(FakeRequest()) status(result) mustBe OK val bodyText = contentAsString(result) - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](bodyText) - val items = listModelResponse.right.value.items + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items items.length mustBe controller.defaultLimit items.map(_.name.toInt).toSet mustBe Set(1, 10, 11, 12, 13, 14, 15, 16, 17, 18) } "send results in a paginated fashion correctly" in { - when(mockedStore.getModels).thenReturn(mockModelRegistry) + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) val startOffset = 3 val number = 6 @@ -60,10 +60,19 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { val expected = Set(12, 13, 14, 15, 16, 17) status(result) mustBe OK val bodyText = contentAsString(result) - val listModelResponse: Either[Error, ListModelResponse] = decode[ListModelResponse](bodyText) - val items = listModelResponse.right.value.items + val listJoinResponse: Either[Error, ListJoinResponse] = decode[ListJoinResponse](bodyText) + val items = listJoinResponse.right.value.items items.length mustBe number items.map(_.name.toInt).toSet mustBe expected } } } + +object MockJoinService { + def generateMockJoin(id: String): Join = { + val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) + Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + } + + val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString)) +} From 622405af048fa131c95dbe61eb1eebd5041006f9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Mon, 25 Nov 2024 10:04:16 -0500 Subject: [PATCH 17/37] Switch to correct metadata table --- online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala | 2 +- spark/src/main/scala/ai/chronon/spark/Driver.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala b/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala index cb1bc52e91..114d3b8ada 100644 --- a/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala +++ b/online/src/main/scala/ai/chronon/online/MetadataEndPoint.scala @@ -22,7 +22,7 @@ case class MetadataEndPoint[Conf <: TBase[_, _]: Manifest: ClassTag]( object MetadataEndPoint { @transient implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass) - val ConfByKeyEndPointName = "ZIPLINE_METADATA" + val ConfByKeyEndPointName = "CHRONON_METADATA" val NameByTeamEndPointName = "CHRONON_ENTITY_BY_TEAM" private def getTeamFromMetadata(metaData: MetaData): String = { diff --git a/spark/src/main/scala/ai/chronon/spark/Driver.scala b/spark/src/main/scala/ai/chronon/spark/Driver.scala index 
92489a90fb..103f382ee8 100644 --- a/spark/src/main/scala/ai/chronon/spark/Driver.scala +++ b/spark/src/main/scala/ai/chronon/spark/Driver.scala @@ -18,6 +18,7 @@ package ai.chronon.spark import ai.chronon.api import ai.chronon.api.Constants +import ai.chronon.api.Constants.MetadataDataset import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api.Extensions.SourceOps @@ -565,7 +566,7 @@ object Driver { lazy val api: Api = impl(serializableProps) def metaDataStore = - new MetadataStore(impl(serializableProps).genKvStore, "ZIPLINE_METADATA", timeoutMillis = 10000) + new MetadataStore(impl(serializableProps).genKvStore, MetadataDataset, timeoutMillis = 10000) def impl(props: Map[String, String]): Api = { val urls = Array(new File(onlineJar()).toURI.toURL) From db0619c26cb8a7355cbb80f6f04fc8e424cc787c Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 24 Nov 2024 15:32:54 -0800 Subject: [PATCH 18/37] observability script for demo --- docker-init/start.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-init/start.sh b/docker-init/start.sh index 9f8b39d9f1..64b777a76e 100644 --- a/docker-init/start.sh +++ b/docker-init/start.sh @@ -19,7 +19,7 @@ fi # Load up metadata into DynamoDB echo "Loading metadata.." -if ! java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then +if ! java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then echo "Error: Failed to load metadata into DynamoDB" >&2 exit 1 fi @@ -27,7 +27,7 @@ echo "Metadata load completed successfully!" # Initialize DynamoDB echo "Initializing DynamoDB Table .." -if ! output=$(java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ +if ! 
output=$(java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ --online-jar=$CLOUD_AWS_JAR \ --online-class=$ONLINE_CLASS 2>&1); then echo "Error: Failed to bring up DynamoDB table" >&2 From 12db7cdd4cf936f660050584a1c87a5d6d475018 Mon Sep 17 00:00:00 2001 From: nikhil-zlai Date: Sun, 24 Nov 2024 18:32:26 -0800 Subject: [PATCH 19/37] running observability demo --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 09aca39240..d516d7a5f0 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -327,9 +327,9 @@ case class TableUtils(sparkSession: SparkSession) { sql(creationSql) } catch { case _: TableAlreadyExistsException => - logger.info(s"Table $tableName already exists, skipping creation") + println(s"Table $tableName already exists, skipping creation") case e: Exception => - logger.error(s"Failed to create table $tableName", e) + println(s"Failed to create table $tableName", e) throw e } } @@ -357,6 +357,7 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } + println(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 0a8c8b327678578f53a1e9938d81d70b932e7ef9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Mon, 25 Nov 2024 22:37:56 -0500 Subject: [PATCH 20/37] Add support for in-memory controller + kv store module --- build.sbt | 16 +++-- .../controllers/InMemKVStoreController.scala | 37 ++++++++++ hub/app/module/DriftStoreModule.scala | 18 +++-- hub/app/module/InMemoryKVStoreModule.scala | 18 +++++ hub/conf/application.conf | 1 + hub/conf/routes | 2 + .../scala/ai/chronon/online/HTTPKVStore.scala | 69 +++++++++++++++++++ 7 files changed, 151 insertions(+), 10 deletions(-) create mode 100644 hub/app/controllers/InMemKVStoreController.scala create mode 100644 hub/app/module/InMemoryKVStoreModule.scala create mode 100644 online/src/main/scala/ai/chronon/online/HTTPKVStore.scala diff --git a/build.sbt b/build.sbt index 6c60365dcb..139945dcba 100644 --- a/build.sbt +++ b/build.sbt @@ -80,6 +80,12 @@ val jackson = Seq( "com.fasterxml.jackson.module" %% "jackson-module-scala" ).map(_ % jackson_2_15) +val circe = Seq( + "io.circe" %% "circe-core", + "io.circe" %% "circe-generic", + "io.circe" %% "circe-parser", +).map(_ % circeVersion) + val flink_all = Seq( "org.apache.flink" %% "flink-streaming-scala", "org.apache.flink" % "flink-metrics-dropwizard", @@ -129,6 +135,10 @@ lazy val online = project "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8" ), libraryDependencies ++= jackson, + // we pull in circe to help us ser case classes like PutRequest without requiring annotations + libraryDependencies ++= circe, + // dep needed for HTTPKvStore - yank when we rip this out + libraryDependencies += "com.softwaremill.sttp.client3" %% "core" % "3.9.7", libraryDependencies ++= spark_all.map(_ % "provided"), libraryDependencies ++= flink_all.map(_ % "provided") ) @@ -236,20 +246,18 @@ lazy val frontend = (project in file("frontend")) // build interop between one module solely on 2.13 and others on 2.12 is painful lazy val hub = (project in file("hub")) .enablePlugins(PlayScala) - .dependsOn(cloud_aws) + 
.dependsOn(cloud_aws, spark) .settings( name := "hub", libraryDependencies ++= Seq( guice, "org.scalatestplus.play" %% "scalatestplus-play" % "5.1.0" % Test, "org.scalatestplus" %% "mockito-3-4" % "3.2.10.0" % "test", - "io.circe" %% "circe-core" % circeVersion, - "io.circe" %% "circe-generic" % circeVersion, - "io.circe" %% "circe-parser" % circeVersion, "org.scala-lang.modules" %% "scala-xml" % "2.1.0", "org.scala-lang.modules" %% "scala-parser-combinators" % "2.3.0", "org.scala-lang.modules" %% "scala-java8-compat" % "1.0.2" ), + libraryDependencies ++= circe, libraryDependencySchemes ++= Seq( "org.scala-lang.modules" %% "scala-xml" % VersionScheme.Always, "org.scala-lang.modules" %% "scala-parser-combinators" % VersionScheme.Always, diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala new file mode 100644 index 0000000000..82e3f1979d --- /dev/null +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -0,0 +1,37 @@ +package controllers + +import ai.chronon.online.KVStore +import ai.chronon.online.KVStore.PutRequest +import play.api.mvc.{BaseController, ControllerComponents} +import io.circe.parser.decode +import play.api.Logger + +import javax.inject.Inject +import scala.concurrent.{ExecutionContext, Future} + +class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { + + import ai.chronon.online.PutRequestCodec._ + + val logger: Logger = Logger(this.getClass) + + def bulkPut() = Action(parse.raw).async { request => + request.body.asBytes() match { + case Some(bytes) => + decode[Array[PutRequest]](bytes.utf8String) match { + case Right(putRequests) => + logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + val resultFuture = kvStore.multiPut(putRequests) + resultFuture.map { + responses => + if (responses.contains(false)) { + logger.warn(s"Some write failures encountered") + } + Ok("Success") + } + case Left(error) => Future.successful(BadRequest(error.getMessage)) + } + case None => Future.successful(BadRequest("Empty body")) + } + } +} diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala index ee831e5cad..b8c12786e3 100644 --- a/hub/app/module/DriftStoreModule.scala +++ b/hub/app/module/DriftStoreModule.scala @@ -1,15 +1,21 @@ package module -import ai.chronon.integrations.aws.AwsApiImpl +import ai.chronon.online.KVStore import ai.chronon.online.stats.DriftStore import com.google.inject.AbstractModule -import play.api.{Configuration, Environment} -class DriftStoreModule(environment: Environment, configuration: Configuration) extends AbstractModule { +import javax.inject.{Inject, Provider} + +class DriftStoreModule extends AbstractModule { override def configure(): Unit = { - val awsApiImpl = new AwsApiImpl(Map.empty) - val driftStore = new DriftStore(awsApiImpl.genKvStore) - bind(classOf[DriftStore]).toInstance(driftStore) + // TODO swap to concrete api impl in a follow up + bind(classOf[DriftStore]).toProvider(classOf[DriftStoreProvider]).asEagerSingleton() + } +} + +class DriftStoreProvider @Inject()(kvStore: KVStore) extends Provider[DriftStore] { + override def get(): DriftStore = { + new DriftStore(kvStore) } } diff --git a/hub/app/module/InMemoryKVStoreModule.scala b/hub/app/module/InMemoryKVStoreModule.scala new file mode 100644 index 0000000000..2467ca4573 --- /dev/null +++ b/hub/app/module/InMemoryKVStoreModule.scala @@ -0,0 +1,18 @@ +package module + 
+import ai.chronon.api.Constants +import ai.chronon.online.KVStore +import ai.chronon.spark.utils.InMemoryKvStore +import com.google.inject.AbstractModule + +// Module that creates and injects an in-memory kv store implementation to allow for quick docker testing +class InMemoryKVStoreModule extends AbstractModule { + + override def configure(): Unit = { + val inMemoryKVStore = InMemoryKvStore.build("hub", () => null) + // create relevant datasets in kv store + inMemoryKVStore.create(Constants.MetadataDataset) + inMemoryKVStore.create(Constants.TiledSummaryDataset) + bind(classOf[KVStore]).toInstance(inMemoryKVStore) + } +} diff --git a/hub/conf/application.conf b/hub/conf/application.conf index 5696edf2d2..292f0d42fc 100644 --- a/hub/conf/application.conf +++ b/hub/conf/application.conf @@ -29,4 +29,5 @@ play.filters.cors { # Add DynamoDB module play.modules.enabled += "module.ModelStoreModule" +play.modules.enabled += "module.InMemoryKVStoreModule" play.modules.enabled += "module.DriftStoreModule" diff --git a/hub/conf/routes b/hub/conf/routes index 5d24c3bd52..8939447e6e 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -15,6 +15,8 @@ GET /api/v1/join/:name/timeseries controllers.TimeSeriesC # when metricType == "drift" - will return time series list of drift values GET /api/v1/join/:join/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(join: String, name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) +# Temporary in-memory kv store endpoint +POST /api/v1/dataset/data controllers.InMemKVStoreController.bulkPut() # TODO - move the core flow to fine-grained end-points #GET /api/v1/feature/:name/timeseries controllers.TimeSeriesController.fetchFeature(name: String, startTs: Long, endTs: Long, metricType: String, metrics: String, granularity: String, offset: Option[String], algorithm: Option[String]) diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala new file mode 100644 index 0000000000..0514a5aab5 --- /dev/null +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -0,0 +1,69 @@ +package ai.chronon.online + +import ai.chronon.online.KVStore.PutRequest +import io.circe._ +import io.circe.generic.semiauto._ +import io.circe.syntax._ +import sttp.client3._ +import sttp.model.StatusCode + +import java.util.Base64 +import scala.concurrent.Future + +// Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) +class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { + import PutRequestCodec._ + + val backend = HttpClientSyncBackend() + val baseUrl = s"http://$host:$port/api/v1/dataset" + + override def multiGet(requests: collection.Seq[KVStore.GetRequest]): Future[collection.Seq[KVStore.GetResponse]] = ??? 
+ + override def multiPut(putRequests: collection.Seq[KVStore.PutRequest]): Future[collection.Seq[Boolean]] = { + if (putRequests.isEmpty) { + Future.successful(Seq.empty) + } else { + // typically should see the same dataset but we break up our calls by dataset to be safe + val requestsByDataset = putRequests.groupBy(_.dataset) + val futures: Seq[Future[Boolean]] = requestsByDataset.map { + case (dataset, requests) => + Future { + basicRequest + .post(uri"$baseUrl/$dataset/data") + .header("Content-Type", "application/json") + .body(requests.asJson.noSpaces) + .send(backend) + }.map { + response => + response.code match { + case StatusCode.Ok => true + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + false + } + } + }.toSeq + + Future.sequence(futures) + } + } + + override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = ??? + + override def create(dataset: String): Unit = { + logger.warn(s"Skipping creation of $dataset in HTTP kv store implementation") + } +} + +object PutRequestCodec { + // Custom codec for byte arrays using Base64 + implicit val byteArrayEncoder: Encoder[Array[Byte]] = + Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + + implicit val byteArrayDecoder: Decoder[Array[Byte]] = + Decoder.decodeString.map(Base64.getDecoder.decode) + + // Derive codec for PutRequest + implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] +} + From e7e2d1685414e9dfa771b48e4ca8804e9d4414b6 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 10:30:43 -0500 Subject: [PATCH 21/37] Clean up scripts to load data and query via time series controller --- docker-init/Dockerfile | 1 + docker-init/demo/README.md | 20 +++++++ docker-init/demo/build.sh | 1 - docker-init/demo/load_summaries.sh | 12 ++++ docker-init/start.sh | 21 +------ .../controllers/InMemKVStoreController.scala | 17 +++++- .../controllers/TimeSeriesController.scala | 10 ++-- .../scala/ai/chronon/online/HTTPKVStore.scala | 57 ++++++++----------- .../spark/scripts/ObservabilityDemo.scala | 4 +- 9 files changed, 83 insertions(+), 60 deletions(-) delete mode 100755 docker-init/demo/build.sh create mode 100755 docker-init/demo/load_summaries.sh diff --git a/docker-init/Dockerfile b/docker-init/Dockerfile index e9b36c564d..b8b24da7d4 100644 --- a/docker-init/Dockerfile +++ b/docker-init/Dockerfile @@ -43,6 +43,7 @@ ENV CHRONON_DRIVER_JAR="/app/cli/spark.jar" # Set up Spark dependencies to help with launching CLI # Copy Spark JARs from the Bitnami image COPY --from=spark-source /opt/bitnami/spark/jars /opt/spark/jars +COPY --from=spark-source /opt/bitnami/spark/bin /opt/spark/bin # Add all Spark JARs to the classpath ENV CLASSPATH=/opt/spark/jars/* diff --git a/docker-init/demo/README.md b/docker-init/demo/README.md index c1abae2d9b..a3f0807eae 100644 --- a/docker-init/demo/README.md +++ b/docker-init/demo/README.md @@ -1,5 +1,25 @@ +# Populate Observability Demo Data +To populate the observability demo data: +* Launch the set of docker containers: +```bash +~/workspace/chronon $ docker-compose -f docker-init/compose.yaml up --build +... 
+app-1 | [info] 2024-11-26 05:10:45,758 [main] INFO play.api.Play - Application started (Prod) (no global state) +app-1 | [info] 2024-11-26 05:10:45,958 [main] INFO play.core.server.AkkaHttpServer - Listening for HTTP on /[0:0:0:0:0:0:0:0]:9000 +``` +(you can skip the --build if you don't wish to rebuild your code) + +Now you can trigger the script to load summary data: +```bash +~/workspace/chronon $ docker-init/demo/load_summaries.sh +... +Done uploading summaries! 🥳 +``` + +# Streamlit local experimentation run build.sh once, and you can repeatedly exec to quickly visualize In first terminal: `sbt spark/assembly` In second terminal: `./run.sh` to load the built jar and serve the data on localhost:8181 In third terminal: `streamlit run viz.py` + diff --git a/docker-init/demo/build.sh b/docker-init/demo/build.sh deleted file mode 100755 index 5627dac2f5..0000000000 --- a/docker-init/demo/build.sh +++ /dev/null @@ -1 +0,0 @@ -docker build -t obs . \ No newline at end of file diff --git a/docker-init/demo/load_summaries.sh b/docker-init/demo/load_summaries.sh new file mode 100755 index 0000000000..61b4d9db95 --- /dev/null +++ b/docker-init/demo/load_summaries.sh @@ -0,0 +1,12 @@ +# Kick off the ObsDemo spark job in the app container + +docker-compose -f docker-init/compose.yaml exec app /opt/spark/bin/spark-submit \ + --master "local[*]" \ + --driver-memory 8g \ + --conf "spark.driver.maxResultSize=6g" \ + --conf "spark.driver.memory=8g" \ + --driver-class-path "/opt/spark/jars/*:/app/cli/*" \ + --conf "spark.driver.host=localhost" \ + --conf "spark.driver.bindAddress=0.0.0.0" \ + --class ai.chronon.spark.scripts.ObservabilityDemo \ + /app/cli/spark.jar diff --git a/docker-init/start.sh b/docker-init/start.sh index 64b777a76e..a3340f894a 100644 --- a/docker-init/start.sh +++ b/docker-init/start.sh @@ -19,7 +19,7 @@ fi # Load up metadata into DynamoDB echo "Loading metadata.." -if ! java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then +if ! java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver metadata-upload --conf-path=/chronon_sample/production/ --online-jar=$CLOUD_AWS_JAR --online-class=$ONLINE_CLASS; then echo "Error: Failed to load metadata into DynamoDB" >&2 exit 1 fi @@ -27,7 +27,7 @@ echo "Metadata load completed successfully!" # Initialize DynamoDB echo "Initializing DynamoDB Table .." -if ! output=$(java -Dlog4j.configurationFile=log4j.properties -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ +if ! output=$(java -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver create-summary-dataset \ --online-jar=$CLOUD_AWS_JAR \ --online-class=$ONLINE_CLASS 2>&1); then echo "Error: Failed to bring up DynamoDB table" >&2 @@ -39,23 +39,6 @@ echo "DynamoDB Table created successfully!" start_time=$(date +%s) -if ! java \ - --add-opens=java.base/sun.nio.ch=ALL-UNNAMED \ - --add-opens=java.base/sun.security.action=ALL-UNNAMED \ - -cp $SPARK_JAR:$CLASSPATH ai.chronon.spark.Driver summarize-and-upload \ - --online-jar=$CLOUD_AWS_JAR \ - --online-class=$ONLINE_CLASS \ - --parquet-path="$(pwd)/drift_data" \ - --conf-path=/chronon_sample/production/ \ - --time-column=transaction_time; then - echo "Error: Failed to load summary data into DynamoDB" >&2 - exit 1 -else - end_time=$(date +%s) - elapsed_time=$((end_time - start_time)) - echo "Summary load completed successfully! Took $elapsed_time seconds." 
-fi - # Add these java options as without them we hit the below error: # throws java.lang.ClassFormatError accessible: module java.base does not "opens java.lang" to unnamed module @36328710 export JAVA_OPTS="--add-opens java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED" diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 82e3f1979d..ba6b226fec 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -2,16 +2,19 @@ package controllers import ai.chronon.online.KVStore import ai.chronon.online.KVStore.PutRequest +import io.circe.generic.semiauto.deriveCodec +import io.circe.{Codec, Decoder, Encoder} import play.api.mvc.{BaseController, ControllerComponents} import io.circe.parser.decode import play.api.Logger +import java.util.Base64 import javax.inject.Inject import scala.concurrent.{ExecutionContext, Future} class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { - import ai.chronon.online.PutRequestCodec._ + import PutRequestCodec._ val logger: Logger = Logger(this.getClass) @@ -35,3 +38,15 @@ class InMemKVStoreController @Inject() (val controllerComponents: ControllerComp } } } + +object PutRequestCodec { + // Custom codec for byte arrays using Base64 + implicit val byteArrayEncoder: Encoder[Array[Byte]] = + Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + + implicit val byteArrayDecoder: Decoder[Array[Byte]] = + Decoder.decodeString.map(Base64.getDecoder.decode) + + // Derive codec for PutRequest + implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] +} diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 625fcd729e..2ceaa6bd47 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -83,7 +83,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) - val maybeDriftSeries = driftStore.getDriftSeries(name, driftMetric, window, startTs, endTs) + val joinPath = name.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name + val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs) maybeDriftSeries match { case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) case Success(driftSeriesFuture) => driftSeriesFuture.map { @@ -140,9 +141,10 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) + val joinPath = join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name if (granularity == Aggregates) { val maybeDriftSeries = - driftStore.getDriftSeries(join, driftMetric, window, startTs, endTs, Some(name)) + driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs, Some(name)) maybeDriftSeries match { case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) case Success(driftSeriesFuture) => driftSeriesFuture.map { @@ -153,8 +155,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } else { // percentiles - val maybeCurrentSummarySeries = driftStore.getSummarySeries(join, startTs, endTs, Some(name)) - val maybeBaselineSummarySeries = driftStore.getSummarySeries(join, startTs - window.millis, endTs - window.millis, Some(name)) + val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)) + val maybeBaselineSummarySeries = driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) case (_, Failure(exception)) => Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala index 0514a5aab5..ff7fbdddf2 100644 --- a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -1,8 +1,6 @@ package ai.chronon.online import ai.chronon.online.KVStore.PutRequest -import io.circe._ -import io.circe.generic.semiauto._ import io.circe.syntax._ import sttp.client3._ import sttp.model.StatusCode @@ -12,7 +10,6 @@ import scala.concurrent.Future // Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { - import PutRequestCodec._ val backend = HttpClientSyncBackend() val baseUrl = s"http://$host:$port/api/v1/dataset" @@ -23,28 +20,21 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore if (putRequests.isEmpty) { Future.successful(Seq.empty) } else { - // typically should see the same dataset but we break up our calls by dataset to be safe - val requestsByDataset = putRequests.groupBy(_.dataset) - val futures: Seq[Future[Boolean]] = requestsByDataset.map { - case (dataset, requests) => - Future { - basicRequest - .post(uri"$baseUrl/$dataset/data") - .header("Content-Type", "application/json") - .body(requests.asJson.noSpaces) - .send(backend) - }.map { - response => - response.code match { - case StatusCode.Ok => true - case _ => - logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") - false - } + Future { + basicRequest + .post(uri"$baseUrl/data") + .header("Content-Type", "application/json") + .body(jsonList(putRequests)) + .send(backend) + 
}.map { + response => + response.code match { + case StatusCode.Ok => Seq(true) + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + Seq(false) } - }.toSeq - - Future.sequence(futures) + } } } @@ -53,17 +43,18 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore override def create(dataset: String): Unit = { logger.warn(s"Skipping creation of $dataset in HTTP kv store implementation") } -} -object PutRequestCodec { - // Custom codec for byte arrays using Base64 - implicit val byteArrayEncoder: Encoder[Array[Byte]] = - Encoder.encodeString.contramap[Array[Byte]](Base64.getEncoder.encodeToString) + // wire up json conversion manually to side step serialization issues in spark executors + def jsonString(request: PutRequest): String = { + val keyBase64 = Base64.getEncoder.encodeToString(request.keyBytes) + val valueBase64 = Base64.getEncoder.encodeToString(request.valueBytes) + s"""{ "keyBytes": "${keyBase64}", "valueBytes": "${valueBase64}", "dataset": "${request.dataset}", "tsMillis": ${request.tsMillis.orNull}}""".stripMargin + } - implicit val byteArrayDecoder: Decoder[Array[Byte]] = - Decoder.decodeString.map(Base64.getDecoder.decode) + def jsonList(requests: Seq[PutRequest]): String = { + val requestsJson = requests.map(jsonString(_)).mkString(", ") - // Derive codec for PutRequest - implicit val putRequestCodec: Codec[PutRequest] = deriveCodec[PutRequest] + s"[ $requestsJson ]" + } } diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index 064b69700f..b2e5015bf6 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -10,7 +10,7 @@ import ai.chronon.api.PartitionSpec import ai.chronon.api.TileDriftSeries import ai.chronon.api.TileSummarySeries import ai.chronon.api.Window -import ai.chronon.online.KVStore +import ai.chronon.online.{HTTPKVStore, KVStore} import ai.chronon.online.stats.DriftStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils @@ -56,7 +56,7 @@ object ObservabilityDemo { // mock api impl for online fetching and uploading val kvStoreFunc: () => KVStore = () => { // cannot reuse the variable - or serialization error - val result = InMemoryKvStore.build(namespace, () => null) + val result = new HTTPKVStore() result } val api = new MockApi(kvStoreFunc, namespace) From 281fc7c8dda3589055200910534698007d12b68d Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 10:48:32 -0500 Subject: [PATCH 22/37] Address scalafix + fmt --- .../controllers/InMemKVStoreController.scala | 48 ++++---- hub/app/controllers/JoinController.scala | 3 +- hub/app/controllers/ModelController.scala | 3 +- hub/app/controllers/SearchController.scala | 6 +- .../controllers/TimeSeriesController.scala | 105 +++++++++++------- hub/app/module/DriftStoreModule.scala | 5 +- .../controllers/SearchControllerSpec.scala | 4 +- .../TimeSeriesControllerSpec.scala | 18 ++- .../scala/ai/chronon/online/HTTPKVStore.scala | 21 ++-- .../ai/chronon/online/stats/DriftStore.scala | 25 ++++- .../spark/scripts/ObservabilityDemo.scala | 18 +-- 11 files changed, 148 insertions(+), 108 deletions(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index ba6b226fec..1101b0454c 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ 
b/hub/app/controllers/InMemKVStoreController.scala @@ -2,41 +2,49 @@ package controllers import ai.chronon.online.KVStore import ai.chronon.online.KVStore.PutRequest +import io.circe.Codec +import io.circe.Decoder +import io.circe.Encoder import io.circe.generic.semiauto.deriveCodec -import io.circe.{Codec, Decoder, Encoder} -import play.api.mvc.{BaseController, ControllerComponents} import io.circe.parser.decode import play.api.Logger +import play.api.mvc.BaseController +import play.api.mvc.ControllerComponents import java.util.Base64 import javax.inject.Inject -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import play.api.mvc +import play.api.mvc.RawBuffer -class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) extends BaseController { +class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit + ec: ExecutionContext) + extends BaseController { import PutRequestCodec._ val logger: Logger = Logger(this.getClass) - def bulkPut() = Action(parse.raw).async { request => - request.body.asBytes() match { - case Some(bytes) => - decode[Array[PutRequest]](bytes.utf8String) match { - case Right(putRequests) => - logger.info(s"Attempting a bulkPut with ${putRequests.length} items") - val resultFuture = kvStore.multiPut(putRequests) - resultFuture.map { - responses => + def bulkPut(): mvc.Action[RawBuffer] = + Action(parse.raw).async { request => + request.body.asBytes() match { + case Some(bytes) => + decode[Array[PutRequest]](bytes.utf8String) match { + case Right(putRequests) => + logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + val resultFuture = kvStore.multiPut(putRequests) + resultFuture.map { responses => if (responses.contains(false)) { - logger.warn(s"Some write failures encountered") + logger.warn("Some write failures encountered") } - Ok("Success") - } - case Left(error) => Future.successful(BadRequest(error.getMessage)) - } - case None => Future.successful(BadRequest("Empty body")) + Ok("Success") + } + case Left(error) => Future.successful(BadRequest(error.getMessage)) + } + case None => Future.successful(BadRequest("Empty body")) + } } - } } object PutRequestCodec { diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala index 81fd73a970..f65de2976c 100644 --- a/hub/app/controllers/JoinController.scala +++ b/hub/app/controllers/JoinController.scala @@ -12,8 +12,7 @@ import javax.inject._ * Controller for the Zipline Join entities */ @Singleton -class JoinController @Inject()(val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class JoinController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/ModelController.scala b/hub/app/controllers/ModelController.scala index 40ef41a56c..66e197191e 100644 --- a/hub/app/controllers/ModelController.scala +++ b/hub/app/controllers/ModelController.scala @@ -12,8 +12,7 @@ import javax.inject._ * Controller for the Zipline models entities */ @Singleton -class ModelController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class ModelController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with 
Paginate { diff --git a/hub/app/controllers/SearchController.scala b/hub/app/controllers/SearchController.scala index a6bd1a8ead..075d03de6b 100644 --- a/hub/app/controllers/SearchController.scala +++ b/hub/app/controllers/SearchController.scala @@ -2,7 +2,8 @@ package controllers import io.circe.generic.auto._ import io.circe.syntax._ -import model.{Join, SearchJoinResponse} +import model.Join +import model.SearchJoinResponse import play.api.mvc._ import store.MonitoringModelStore @@ -11,8 +12,7 @@ import javax.inject._ /** * Controller to power search related APIs */ -class SearchController @Inject() (val controllerComponents: ControllerComponents, - monitoringStore: MonitoringModelStore) +class SearchController @Inject() (val controllerComponents: ControllerComponents, monitoringStore: MonitoringModelStore) extends BaseController with Paginate { diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index 2ceaa6bd47..ca5207c688 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ -1,6 +1,10 @@ package controllers +import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.WindowOps -import ai.chronon.api.{DriftMetric, TileDriftSeries, TileSummarySeries, TimeUnit, Window} +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSummarySeries +import ai.chronon.api.TimeUnit +import ai.chronon.api.Window import ai.chronon.online.stats.DriftStore import io.circe.generic.auto._ import io.circe.syntax._ @@ -8,16 +12,20 @@ import model._ import play.api.mvc._ import javax.inject._ -import scala.concurrent.{ExecutionContext, Future} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future import scala.concurrent.duration._ -import scala.util.{Failure, Random, Success} import scala.jdk.CollectionConverters._ +import scala.util.Failure +import scala.util.Success /** * Controller that serves various time series endpoints at the model, join and feature level */ @Singleton -class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit ec: ExecutionContext) extends BaseController { +class TimeSeriesController @Inject() (val controllerComponents: ControllerComponents, driftStore: DriftStore)(implicit + ec: ExecutionContext) + extends BaseController { import TimeSeriesController._ @@ -79,25 +87,29 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon algorithm: Option[String]): Future[Result] = { (parseOffset(offset), parseAlgorithm(algorithm)) match { - case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) - case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) + case (None, _) => Future.successful(BadRequest(s"Unable to parse offset - $offset")) + case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. 
Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) val joinPath = name.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs) maybeDriftSeries match { - case Failure(exception) => Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) - case Success(driftSeriesFuture) => driftSeriesFuture.map { - driftSeries => + case Failure(exception) => + Future.successful(InternalServerError(s"Error computing join drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => + driftSeriesFuture.map { driftSeries => // pull up a list of drift series objects for all the features in a group val grpToDriftSeriesList: Map[String, Seq[TileDriftSeries]] = driftSeries.groupBy(_.key.groupName) val groupByTimeSeries = grpToDriftSeriesList.map { - case (name, featureDriftSeriesInfoSeq) => GroupByTimeSeries(name, featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) + case (name, featureDriftSeriesInfoSeq) => + GroupByTimeSeries( + name, + featureDriftSeriesInfoSeq.map(series => convertTileDriftSeriesInfoToTimeSeries(series, metric))) }.toSeq - val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) - Ok(tsData.asJson.noSpaces) - } + val tsData = JoinTimeSeriesResponse(name, groupByTimeSeries) + Ok(tsData.asJson.noSpaces) + } } } } @@ -141,35 +153,45 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (_, None) => Future.successful(BadRequest("Invalid drift algorithm. Expect JSD, PSI or Hellinger")) case (Some(o), Some(driftMetric)) => val window = new Window(o.toMinutes.toInt, TimeUnit.MINUTES) - val joinPath = join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name + val joinPath = + join.replaceFirst("\\.", "/") // we need to look up in the drift store with this transformed name if (granularity == Aggregates) { val maybeDriftSeries = driftStore.getDriftSeries(joinPath, driftMetric, window, startTs, endTs, Some(name)) maybeDriftSeries match { - case Failure(exception) => Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) - case Success(driftSeriesFuture) => driftSeriesFuture.map { - driftSeries => + case Failure(exception) => + Future.successful(InternalServerError(s"Error computing feature drift - ${exception.getMessage}")) + case Success(driftSeriesFuture) => + driftSeriesFuture.map { driftSeries => val featureTs = convertTileDriftSeriesInfoToTimeSeries(driftSeries.head, metric) Ok(featureTs.asJson.noSpaces) - } + } } } else { // percentiles val maybeCurrentSummarySeries = driftStore.getSummarySeries(joinPath, startTs, endTs, Some(name)) - val maybeBaselineSummarySeries = driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) + val maybeBaselineSummarySeries = + driftStore.getSummarySeries(joinPath, startTs - window.millis, endTs - window.millis, Some(name)) (maybeCurrentSummarySeries, maybeBaselineSummarySeries) match { - case (Failure(exceptionA), Failure(exceptionB)) => Future.successful(InternalServerError(s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) - case (_, Failure(exception)) => 
Future.successful(InternalServerError(s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) - case (Failure(exception), _) => Future.successful(InternalServerError(s"Error computing feature percentiles for current time window - ${exception.getMessage}")) + case (Failure(exceptionA), Failure(exceptionB)) => + Future.successful(InternalServerError( + s"Error computing feature percentiles for current + offset time window.\nCurrent window error: ${exceptionA.getMessage}\nOffset window error: ${exceptionB.getMessage}")) + case (_, Failure(exception)) => + Future.successful( + InternalServerError( + s"Error computing feature percentiles for offset time window - ${exception.getMessage}")) + case (Failure(exception), _) => + Future.successful( + InternalServerError( + s"Error computing feature percentiles for current time window - ${exception.getMessage}")) case (Success(currentSummarySeriesFuture), Success(baselineSummarySeriesFuture)) => - Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { - merged => - val currentSummarySeries = merged.head - val baselineSummarySeries = merged.last - val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) - val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) - val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) - Ok(comparedTsData.asJson.noSpaces) + Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => + val currentSummarySeries = merged.head + val baselineSummarySeries = merged.last + val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) + Ok(comparedTsData.asJson.noSpaces) } } } @@ -177,13 +199,16 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } } - private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, metric: Metric): FeatureTimeSeries = { + private def convertTileDriftSeriesInfoToTimeSeries(tileDriftSeries: TileDriftSeries, + metric: Metric): FeatureTimeSeries = { val lhsList = if (metric == NullMetric) { tileDriftSeries.nullRatioChangePercentSeries.asScala } else { // check if we have a numeric / categorical feature. 
If the percentile drift series has non-null doubles // then we have a numeric feature at hand - val isNumeric = tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala.exists(_ != null) + val isNumeric = + tileDriftSeries.percentileDriftSeries.asScala != null && tileDriftSeries.percentileDriftSeries.asScala + .exists(_ != null) if (isNumeric) tileDriftSeries.percentileDriftSeries.asScala else tileDriftSeries.histogramDriftSeries.asScala } @@ -194,7 +219,8 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon FeatureTimeSeries(tileDriftSeries.getKey.getColumn, points) } - private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, metric: Metric): Seq[TimeSeriesPoint] = { + private def convertTileSummarySeriesToTimeSeries(summarySeries: TileSummarySeries, + metric: Metric): Seq[TimeSeriesPoint] = { if (metric == NullMetric) { summarySeries.nullCount.asScala.zip(summarySeries.timestamps.asScala).map { case (nullCount, ts) => TimeSeriesPoint(0, ts, nullValue = Some(nullCount.intValue())) @@ -202,7 +228,7 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon } else { // check if we have a numeric / categorical feature. If the percentile drift series has non-null doubles // then we have a numeric feature at hand - val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) + val isNumeric = summarySeries.percentiles.asScala != null && summarySeries.percentiles.asScala.exists(_ != null) if (isNumeric) { summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap { case (percentiles, ts) => @@ -210,8 +236,7 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon case (l, value) => TimeSeriesPoint(value, ts, Some(l)) } } - } - else { + } else { summarySeries.timestamps.asScala.zipWithIndex.flatMap { case (ts, idx) => summarySeries.histogram.asScala.map { @@ -238,10 +263,10 @@ object TimeSeriesController { def parseAlgorithm(algorithm: Option[String]): Option[DriftMetric] = { algorithm.map(_.toLowerCase) match { - case Some("psi") => Some(DriftMetric.PSI) + case Some("psi") => Some(DriftMetric.PSI) case Some("hellinger") => Some(DriftMetric.HELLINGER) - case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) - case _ => None + case Some("jsd") => Some(DriftMetric.JENSEN_SHANNON) + case _ => None } } @@ -251,7 +276,7 @@ object TimeSeriesController { case Some("drift") => Some(Drift) // case Some("skew") => Some(Skew) // case Some("ooc") => Some(Skew) - case _ => None + case _ => None } } diff --git a/hub/app/module/DriftStoreModule.scala b/hub/app/module/DriftStoreModule.scala index b8c12786e3..6456626375 100644 --- a/hub/app/module/DriftStoreModule.scala +++ b/hub/app/module/DriftStoreModule.scala @@ -4,7 +4,8 @@ import ai.chronon.online.KVStore import ai.chronon.online.stats.DriftStore import com.google.inject.AbstractModule -import javax.inject.{Inject, Provider} +import javax.inject.Inject +import javax.inject.Provider class DriftStoreModule extends AbstractModule { @@ -14,7 +15,7 @@ class DriftStoreModule extends AbstractModule { } } -class DriftStoreProvider @Inject()(kvStore: KVStore) extends Provider[DriftStore] { +class DriftStoreProvider @Inject() (kvStore: KVStore) extends Provider[DriftStore] { override def get(): DriftStore = { new DriftStore(kvStore) } diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala 
index b188336fa9..95a0680674 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ b/hub/test/controllers/SearchControllerSpec.scala @@ -4,7 +4,9 @@ import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ -import model.{GroupBy, Join, ListJoinResponse} +import model.GroupBy +import model.Join +import model.ListJoinResponse import org.mockito.Mockito.mock import org.mockito.Mockito.when import org.scalatest.EitherValues diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala index 77d1f16806..c32d87226e 100644 --- a/hub/test/controllers/TimeSeriesControllerSpec.scala +++ b/hub/test/controllers/TimeSeriesControllerSpec.scala @@ -1,13 +1,16 @@ package controllers -import ai.chronon.api.{TileDriftSeries, TileSeriesKey, TileSummarySeries} +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSeriesKey +import ai.chronon.api.TileSummarySeries import ai.chronon.online.stats.DriftStore import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ import model._ import org.mockito.ArgumentMatchers.any -import org.mockito.Mockito.{mock, when} +import org.mockito.Mockito.mock +import org.mockito.Mockito.when import org.scalatest.EitherValues import org.scalatestplus.play._ import play.api.http.Status.BAD_REQUEST @@ -16,13 +19,16 @@ import play.api.mvc._ import play.api.test.Helpers._ import play.api.test._ -import java.util.concurrent.TimeUnit -import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.Duration -import scala.util.{Failure, Success, Try} import java.lang.{Double => JDouble} import java.lang.{Long => JLong} +import java.util.concurrent.TimeUnit +import scala.concurrent.ExecutionContext +import scala.concurrent.Future +import scala.concurrent.duration.Duration import scala.jdk.CollectionConverters._ +import scala.util.Failure +import scala.util.Success +import scala.util.Try class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues { diff --git a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala index ff7fbdddf2..e050331db6 100644 --- a/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala +++ b/online/src/main/scala/ai/chronon/online/HTTPKVStore.scala @@ -1,7 +1,6 @@ package ai.chronon.online import ai.chronon.online.KVStore.PutRequest -import io.circe.syntax._ import sttp.client3._ import sttp.model.StatusCode @@ -11,8 +10,8 @@ import scala.concurrent.Future // Hacky test kv store that we use to send objects to the in-memory KV store that lives in a different JVM (e.g spark -> hub) class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore with Serializable { - val backend = HttpClientSyncBackend() - val baseUrl = s"http://$host:$port/api/v1/dataset" + val backend: SttpBackend[Identity, Any] = HttpClientSyncBackend() + val baseUrl: String = s"http://$host:$port/api/v1/dataset" override def multiGet(requests: collection.Seq[KVStore.GetRequest]): Future[collection.Seq[KVStore.GetResponse]] = ??? 
@@ -26,14 +25,13 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore .header("Content-Type", "application/json") .body(jsonList(putRequests)) .send(backend) - }.map { - response => - response.code match { - case StatusCode.Ok => Seq(true) - case _ => - logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") - Seq(false) - } + }.map { response => + response.code match { + case StatusCode.Ok => Seq(true) + case _ => + logger.error(s"HTTP multiPut failed with status ${response.code}: ${response.body}") + Seq(false) + } } } } @@ -57,4 +55,3 @@ class HTTPKVStore(host: String = "localhost", port: Int = 9000) extends KVStore s"[ $requestsJson ]" } } - diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 44c83e37c1..02812082e0 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -205,8 +205,25 @@ object DriftStore { def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries - val percentileLabels: Seq[String] = Seq("p0", "p5", "p10", "p15", "p20", - "p25", "p30", "p35", "p40", "p45", - "p50", "p55", "p60", "p65", "p70", - "p75", "p80", "p85", "p90", "p95", "p100") + val percentileLabels: Seq[String] = Seq("p0", + "p5", + "p10", + "p15", + "p20", + "p25", + "p30", + "p35", + "p40", + "p45", + "p50", + "p55", + "p60", + "p65", + "p70", + "p75", + "p80", + "p85", + "p90", + "p95", + "p100") } diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index b2e5015bf6..12622e222f 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -1,34 +1,20 @@ package ai.chronon.spark.scripts - -import ai.chronon import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants -import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.api.Extensions.WindowOps -import ai.chronon.api.PartitionSpec -import ai.chronon.api.TileDriftSeries -import ai.chronon.api.TileSummarySeries -import ai.chronon.api.Window -import ai.chronon.online.{HTTPKVStore, KVStore} -import ai.chronon.online.stats.DriftStore +import ai.chronon.online.HTTPKVStore +import ai.chronon.online.KVStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Summarizer import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData -import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.utils.MockApi import org.rogach.scallop.ScallopConf import org.rogach.scallop.ScallopOption import org.slf4j.Logger import org.slf4j.LoggerFactory -import java.util.concurrent.TimeUnit -import scala.concurrent.Await -import scala.concurrent.duration.Duration -import scala.util.ScalaJavaConversions.IteratorOps - object ObservabilityDemo { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) From 7fc0637fabd1efbfe0113536676a09256243f48a Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 13:10:46 -0500 Subject: [PATCH 23/37] Update colPrefix pass through in drift store --- 
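Note (below the `---`, not part of the commit message): previously the column prefix was handed to `tileKeysForJoin` as its second argument; this one-liner supplies `None` for the intervening parameter and moves `columnPrefix` to the third position, so the prefix presumably reaches the column-scoping logic it was intended for. For orientation, the column-scoped lookups that sit on top of this plumbing, as they are invoked from `TimeSeriesController` elsewhere in this series, look roughly like the fragment below. This is illustrative only: `driftStore` is assumed to be a `DriftStore` built over a `KVStore`, and the join path, column name and timestamps are invented example values, not taken from the patch.

```scala
// Illustrative fragment, not part of this patch. Call shapes mirror the controller code
// in this series; "risk/txn_join" and "txn_amount" are made-up example values.
val window     = new Window(60, TimeUnit.MINUTES)
val startTs    = 1700000000000L
val endTs      = startTs + 24 * 3600 * 1000L
val driftTry   = driftStore.getDriftSeries("risk/txn_join", DriftMetric.JENSEN_SHANNON,
                                           window, startTs, endTs, Some("txn_amount"))
val summaryTry = driftStore.getSummarySeries("risk/txn_join", startTs, endTs, Some("txn_amount"))
```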
online/src/main/scala/ai/chronon/online/stats/DriftStore.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 02812082e0..733ebbfc6d 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -77,7 +77,7 @@ class DriftStore(kvStore: KVStore, columnPrefix: Option[String]): Future[Seq[TileSummaryInfo]] = { val serializer: TSerializer = compactSerializer - val tileKeyMap = tileKeysForJoin(joinConf, columnPrefix) + val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix) val requestContextMap: Map[GetRequest, SummaryRequestContext] = tileKeyMap.flatMap { case (group, keys) => keys.map { key => From 4d78fa58716d53f50792e5bcb4ed832c73d6ab82 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Tue, 26 Nov 2024 14:42:12 -0500 Subject: [PATCH 24/37] Add details to join response + join get endpoint --- .../controllers/InMemKVStoreController.scala | 4 ++-- hub/app/controllers/JoinController.scala | 13 +++++++++++++ hub/app/model/Model.scala | 7 ++++++- hub/app/store/MonitoringModelStore.scala | 14 ++++++++++++-- hub/conf/routes | 1 + hub/test/controllers/JoinControllerSpec.scala | 18 ++++++++++++++++++ hub/test/controllers/ModelControllerSpec.scala | 2 +- .../controllers/SearchControllerSpec.scala | 2 +- 8 files changed, 54 insertions(+), 7 deletions(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 1101b0454c..6eda57c4cf 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -8,15 +8,15 @@ import io.circe.Encoder import io.circe.generic.semiauto.deriveCodec import io.circe.parser.decode import play.api.Logger +import play.api.mvc import play.api.mvc.BaseController import play.api.mvc.ControllerComponents +import play.api.mvc.RawBuffer import java.util.Base64 import javax.inject.Inject import scala.concurrent.ExecutionContext import scala.concurrent.Future -import play.api.mvc -import play.api.mvc.RawBuffer class InMemKVStoreController @Inject() (val controllerComponents: ControllerComponents, kvStore: KVStore)(implicit ec: ExecutionContext) diff --git a/hub/app/controllers/JoinController.scala b/hub/app/controllers/JoinController.scala index f65de2976c..383b6438d5 100644 --- a/hub/app/controllers/JoinController.scala +++ b/hub/app/controllers/JoinController.scala @@ -38,4 +38,17 @@ class JoinController @Inject() (val controllerComponents: ControllerComponents, Ok(json) } } + + /** + * Returns a specific join by name + */ + def get(name: String): Action[AnyContent] = { + Action { implicit request: Request[AnyContent] => + val maybeJoin = monitoringStore.getJoins.find(j => j.name.equalsIgnoreCase(name)) + maybeJoin match { + case None => NotFound(s"Join: $name wasn't found") + case Some(join) => Ok(join.asJson.noSpaces) + } + } + } } diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index f3e2d9f54c..eb523251d3 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -2,7 +2,12 @@ package model /** Captures some details related to ML models registered with Zipline to surface in the Hub UI */ case class GroupBy(name: String, features: Seq[String]) -case class Join(name: String, joinFeatures: Seq[String], groupBys: Seq[GroupBy]) +case class Join(name: String, + joinFeatures: Seq[String], + 
groupBys: Seq[GroupBy], + online: Boolean, + production: Boolean, + team: String) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) // 1.) metadataUpload: join -> map> diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala index 9dd11d280c..89fbc5fa3e 100644 --- a/hub/app/store/MonitoringModelStore.scala +++ b/hub/app/store/MonitoringModelStore.scala @@ -60,7 +60,12 @@ class MonitoringModelStore(apiImpl: Api) { } val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) - val join = Join(thriftJoin.metaData.name, outputColumns, groupBys) + val join = Join(thriftJoin.metaData.name, + outputColumns, + groupBys, + thriftJoin.metaData.online, + thriftJoin.metaData.production, + thriftJoin.metaData.team) Option( Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name())) } else { @@ -76,7 +81,12 @@ class MonitoringModelStore(apiImpl: Api) { } val outputColumns = thriftJoin.outputColumnsByGroup.getOrElse("derivations", Array.empty) - Join(thriftJoin.metaData.name, outputColumns, groupBys) + Join(thriftJoin.metaData.name, + outputColumns, + groupBys, + thriftJoin.metaData.online, + thriftJoin.metaData.production, + thriftJoin.metaData.team) } } diff --git a/hub/conf/routes b/hub/conf/routes index 8939447e6e..9ec36cc087 100644 --- a/hub/conf/routes +++ b/hub/conf/routes @@ -2,6 +2,7 @@ GET /api/v1/ping controllers.ApplicationController.ping() GET /api/v1/models controllers.ModelController.list(offset: Option[Int], limit: Option[Int]) GET /api/v1/joins controllers.JoinController.list(offset: Option[Int], limit: Option[Int]) +GET /api/v1/join/:name controllers.JoinController.get(name: String) GET /api/v1/search controllers.SearchController.search(term: String, offset: Option[Int], limit: Option[Int]) # model prediction & model drift - this is TBD at the moment diff --git a/hub/test/controllers/JoinControllerSpec.scala b/hub/test/controllers/JoinControllerSpec.scala index b6924cfdd1..8cf8ad7c38 100644 --- a/hub/test/controllers/JoinControllerSpec.scala +++ b/hub/test/controllers/JoinControllerSpec.scala @@ -4,6 +4,7 @@ import controllers.MockJoinService.mockJoinRegistry import io.circe._ import io.circe.generic.auto._ import io.circe.parser._ +import model.Join import model.ListJoinResponse import org.mockito.Mockito._ import org.scalatest.EitherValues @@ -36,6 +37,13 @@ class JoinControllerSpec extends PlaySpec with Results with EitherValues { status(result) mustBe BAD_REQUEST } + "send 404 on missing join" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.get("fake_join").apply(FakeRequest()) + status(result) mustBe NOT_FOUND + } + "send valid results on a correctly formed request" in { when(mockedStore.getJoins).thenReturn(mockJoinRegistry) @@ -61,5 +69,15 @@ class JoinControllerSpec extends PlaySpec with Results with EitherValues { items.length mustBe number items.map(_.name.toInt).toSet mustBe (startOffset until startOffset + number).toSet } + + "send valid join object on specific join lookup" in { + when(mockedStore.getJoins).thenReturn(mockJoinRegistry) + + val result = controller.get("10").apply(FakeRequest()) + status(result) mustBe OK + val bodyText = contentAsString(result) + val joinResponse: Either[Error, Join] = decode[Join](bodyText) + joinResponse.right.value.name mustBe "10" + } } } diff --git a/hub/test/controllers/ModelControllerSpec.scala 
b/hub/test/controllers/ModelControllerSpec.scala index 95b96ca24a..d68b536e39 100644 --- a/hub/test/controllers/ModelControllerSpec.scala +++ b/hub/test/controllers/ModelControllerSpec.scala @@ -71,7 +71,7 @@ class ModelControllerSpec extends PlaySpec with Results with EitherValues { object MockDataService { def generateMockModel(id: String): Model = { val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team") Model(id, join, online = true, production = true, "my team", "XGBoost") } diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala index 95a0680674..8ecca5c5c7 100644 --- a/hub/test/controllers/SearchControllerSpec.scala +++ b/hub/test/controllers/SearchControllerSpec.scala @@ -73,7 +73,7 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues { object MockJoinService { def generateMockJoin(id: String): Join = { val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2"))) - Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys) + Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team") } val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString)) From 7e7d2a17cb24b3d1b0f4be164033b4285a2fdbc3 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 10:31:21 -0500 Subject: [PATCH 25/37] Rebase + comments --- .github/workflows/test_scala_no_spark.yaml | 7 ++++++- api/src/main/scala/ai/chronon/api/ColorPrinter.scala | 3 +++ build.sbt | 2 -- hub/app/model/Model.scala | 7 ------- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 6 +++--- 5 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.github/workflows/test_scala_no_spark.yaml b/.github/workflows/test_scala_no_spark.yaml index 77edb859f9..00c4bc25f0 100644 --- a/.github/workflows/test_scala_no_spark.yaml +++ b/.github/workflows/test_scala_no_spark.yaml @@ -60,4 +60,9 @@ jobs: - name: Run api tests run: | - sbt "++ 2.12.18 api/test" \ No newline at end of file + sbt "++ 2.12.18 api/test" + + - name: Run hub tests + run: | + export SBT_OPTS="-Xmx8G -Xms2G" + sbt "++ 2.12.18 hub/test" diff --git a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala index e779e3eaf1..4d1dc57c50 100644 --- a/api/src/main/scala/ai/chronon/api/ColorPrinter.scala +++ b/api/src/main/scala/ai/chronon/api/ColorPrinter.scala @@ -11,11 +11,14 @@ object ColorPrinter { private val ANSI_YELLOW = "\u001B[38;5;172m" // Muted Orange private val ANSI_GREEN = "\u001B[38;5;28m" // Forest green + private val BOLD = "\u001B[1m" + implicit class ColorString(val s: String) extends AnyVal { def red: String = s"$ANSI_RED$s$ANSI_RESET" def blue: String = s"$ANSI_BLUE$s$ANSI_RESET" def yellow: String = s"$ANSI_YELLOW$s$ANSI_RESET" def green: String = s"$ANSI_GREEN$s$ANSI_RESET" def low: String = s.toLowerCase + def highlight: String = s"$BOLD$ANSI_RED$s$ANSI_RESET" } } diff --git a/build.sbt b/build.sbt index 139945dcba..39aabe6364 100644 --- a/build.sbt +++ b/build.sbt @@ -135,8 +135,6 @@ lazy val online = project "com.github.ben-manes.caffeine" % "caffeine" % "3.1.8" ), libraryDependencies ++= jackson, - // we pull in circe to help us ser case classes like PutRequest without requiring annotations - libraryDependencies ++= circe, // dep needed for HTTPKvStore - yank when we rip this out 
libraryDependencies += "com.softwaremill.sttp.client3" %% "core" % "3.9.7", libraryDependencies ++= spark_all.map(_ % "provided"), diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala index eb523251d3..a83c1c1803 100644 --- a/hub/app/model/Model.scala +++ b/hub/app/model/Model.scala @@ -10,13 +10,6 @@ case class Join(name: String, team: String) case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String) -// 1.) metadataUpload: join -> map> -// 2.) fetchJoinConf + listColumns: join => list -// 3.) (columns, start, end) -> list - -// 4.) 1:n/fetchTile: tileKey -> TileSummaries -// 5.) 1:n:n/compareTiles: TileSummaries, TileSummaries -> TileDrift -// 6.) Map[column, Seq[tileDrift]] -> TimeSeriesController /** Supported Metric types */ sealed trait MetricType diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index d516d7a5f0..5595b46f59 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -327,9 +327,9 @@ case class TableUtils(sparkSession: SparkSession) { sql(creationSql) } catch { case _: TableAlreadyExistsException => - println(s"Table $tableName already exists, skipping creation") + logger.info(s"Table $tableName already exists, skipping creation") case e: Exception => - println(s"Failed to create table $tableName", e) + logger.error(s"Failed to create table $tableName", e) throw e } } @@ -357,7 +357,7 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } - println(s"Repartitioning and writing into table $tableName".yellow) + logger.info(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 470fd9eb35b564a8106321a73654bb3340027a87 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 10:42:11 -0500 Subject: [PATCH 26/37] Revert TableUtils for now --- spark/src/main/scala/ai/chronon/spark/TableUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala index 5595b46f59..09aca39240 100644 --- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala @@ -357,7 +357,6 @@ case class TableUtils(sparkSession: SparkSession) { // so that an exception will be thrown below dfRearranged } - logger.info(s"Repartitioning and writing into table $tableName".yellow) repartitionAndWrite(finalizedDf, tableName, saveMode, stats, sortByCols) } From 18e260c445cb5eed427b29f201a0aadb400ee4b7 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:14:39 -0500 Subject: [PATCH 27/37] Swap to new uploader app and revert old code --- build.sbt | 5 +- docker-init/demo/build.sh | 1 + docker-init/demo/load_summaries.sh | 2 +- .../spark/scripts/ObservabilityDemo.scala | 18 ++- .../scripts/ObservabilityDemoDataLoader.scala | 114 ++++++++++++++++++ 5 files changed, 136 insertions(+), 4 deletions(-) create mode 100755 docker-init/demo/build.sh create mode 100644 spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala diff --git a/build.sbt b/build.sbt index 39aabe6364..fd4852dfb4 100644 --- a/build.sbt +++ b/build.sbt @@ -264,7 +264,10 @@ lazy val hub = (project in file("hub")) excludeDependencies ++= Seq( ExclusionRule(organization = 
"org.slf4j", name = "slf4j-log4j12"), ExclusionRule(organization = "log4j", name = "log4j"), - ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-to-slf4j") + ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-to-slf4j"), + ExclusionRule("org.apache.logging.log4j", "log4j-slf4j-impl"), + ExclusionRule("org.apache.logging.log4j", "log4j-core"), + ExclusionRule("org.apache.logging.log4j", "log4j-api") ), // Ensure consistent versions of logging libraries dependencyOverrides ++= Seq( diff --git a/docker-init/demo/build.sh b/docker-init/demo/build.sh new file mode 100755 index 0000000000..5627dac2f5 --- /dev/null +++ b/docker-init/demo/build.sh @@ -0,0 +1 @@ +docker build -t obs . \ No newline at end of file diff --git a/docker-init/demo/load_summaries.sh b/docker-init/demo/load_summaries.sh index 61b4d9db95..15bc3681a0 100755 --- a/docker-init/demo/load_summaries.sh +++ b/docker-init/demo/load_summaries.sh @@ -8,5 +8,5 @@ docker-compose -f docker-init/compose.yaml exec app /opt/spark/bin/spark-submit --driver-class-path "/opt/spark/jars/*:/app/cli/*" \ --conf "spark.driver.host=localhost" \ --conf "spark.driver.bindAddress=0.0.0.0" \ - --class ai.chronon.spark.scripts.ObservabilityDemo \ + --class ai.chronon.spark.scripts.ObservabilityDemoDataLoader \ /app/cli/spark.jar diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala index 12622e222f..064b69700f 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemo.scala @@ -1,20 +1,34 @@ package ai.chronon.spark.scripts + +import ai.chronon import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants +import ai.chronon.api.DriftMetric import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.online.HTTPKVStore +import ai.chronon.api.Extensions.WindowOps +import ai.chronon.api.PartitionSpec +import ai.chronon.api.TileDriftSeries +import ai.chronon.api.TileSummarySeries +import ai.chronon.api.Window import ai.chronon.online.KVStore +import ai.chronon.online.stats.DriftStore import ai.chronon.spark.SparkSessionBuilder import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Summarizer import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData +import ai.chronon.spark.utils.InMemoryKvStore import ai.chronon.spark.utils.MockApi import org.rogach.scallop.ScallopConf import org.rogach.scallop.ScallopOption import org.slf4j.Logger import org.slf4j.LoggerFactory +import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.util.ScalaJavaConversions.IteratorOps + object ObservabilityDemo { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) @@ -42,7 +56,7 @@ object ObservabilityDemo { // mock api impl for online fetching and uploading val kvStoreFunc: () => KVStore = () => { // cannot reuse the variable - or serialization error - val result = new HTTPKVStore() + val result = InMemoryKvStore.build(namespace, () => null) result } val api = new MockApi(kvStoreFunc, namespace) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala new file mode 100644 index 0000000000..65275b8d94 --- /dev/null +++ 
b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -0,0 +1,114 @@ +package ai.chronon.spark.scripts + +import ai.chronon.api.ColorPrinter.ColorString +import ai.chronon.api.Constants +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.online.{HTTPKVStore, KVStore} +import ai.chronon.spark.{SparkSessionBuilder, TableUtils} +import ai.chronon.spark.stats.drift.{Summarizer, SummaryUploader} +import ai.chronon.spark.stats.drift.scripts.PrepareData +import ai.chronon.spark.utils.{InMemoryKvStore, MockApi} +import org.rogach.scallop.{ScallopConf, ScallopOption} +import org.slf4j.{Logger, LoggerFactory} + +object ObservabilityDemoDataLoader { + @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) + + def time(message: String)(block: => Unit): Unit = { + logger.info(s"$message..".yellow) + val start = System.currentTimeMillis() + block + val end = System.currentTimeMillis() + logger.info(s"$message took ${end - start} ms".green) + } + + class Conf(arguments: Seq[String]) extends ScallopConf(arguments) { + val startDs: ScallopOption[String] = opt[String]( + name = "start-ds", + default = Some("2023-01-01"), + descr = "Start date in YYYY-MM-DD format" + ) + + val endDs: ScallopOption[String] = opt[String]( + name = "end-ds", + default = Some("2023-02-30"), + descr = "End date in YYYY-MM-DD format" + ) + + val rowCount: ScallopOption[Int] = opt[Int]( + name = "row-count", + default = Some(700000), + descr = "Number of rows to generate" + ) + + val namespace: ScallopOption[String] = opt[String]( + name = "namespace", + default = Some("observability_demo"), + descr = "Namespace for the demo" + ) + + verify() + } + + def main(args: Array[String]): Unit = { + + val config = new Conf(args) + val startDs = config.startDs() + val endDs = config.endDs() + val rowCount = config.rowCount() + val namespace = config.namespace() + + val spark = SparkSessionBuilder.build(namespace, local = true) + implicit val tableUtils: TableUtils = TableUtils(spark) + tableUtils.createDatabase(namespace) + + // generate anomalous data (join output) + val prepareData = PrepareData(namespace) + val join = prepareData.generateAnomalousFraudJoin + + time("Preparing data") { + val df = prepareData.generateFraudSampleData(rowCount, startDs, endDs, join.metaData.loggedTable) + df.show(10, truncate = false) + } + + // mock api impl for online fetching and uploading + val inMemKvStoreFunc: () => KVStore = () => { + // cannot reuse the variable - or serialization error + val result = InMemoryKvStore.build(namespace, () => null) + result + } + val inMemoryApi = new MockApi(inMemKvStoreFunc, namespace) + + time("Summarizing data") { + // compute summary table and packed table (for uploading) + Summarizer.compute(inMemoryApi, join.metaData, ds = endDs, useLogs = true) + } + + val packedTable = join.metaData.packedSummaryTable + + // create necessary tables in kvstore - we now publish to the HTTP KV store as we need this available to the Hub + val httpKvStoreFunc: () => KVStore = () => { + // cannot reuse the variable - or serialization error + val result = new HTTPKVStore() + result + } + val hubApi = new MockApi(httpKvStoreFunc, namespace) + + val kvStore = hubApi.genKvStore + kvStore.create(Constants.MetadataDataset) + kvStore.create(Constants.TiledSummaryDataset) + + // upload join conf + hubApi.buildFetcher().putJoinConf(join) + + time("Uploading summaries") { + val uploader = new SummaryUploader(tableUtils.loadTable(packedTable), hubApi) + uploader.run() + } + + 
println("Done uploading summaries! \uD83E\uDD73".green) + // clean up spark session and force jvm exit + spark.stop() + System.exit(0) + } +} From de044c4eab538e98d5bf9c698ae1a5ad98275ea2 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:14:59 -0500 Subject: [PATCH 28/37] Downgrade in mem controller log to debug --- hub/app/controllers/InMemKVStoreController.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hub/app/controllers/InMemKVStoreController.scala b/hub/app/controllers/InMemKVStoreController.scala index 6eda57c4cf..97768e86b8 100644 --- a/hub/app/controllers/InMemKVStoreController.scala +++ b/hub/app/controllers/InMemKVStoreController.scala @@ -32,7 +32,7 @@ class InMemKVStoreController @Inject() (val controllerComponents: ControllerComp case Some(bytes) => decode[Array[PutRequest]](bytes.utf8String) match { case Right(putRequests) => - logger.info(s"Attempting a bulkPut with ${putRequests.length} items") + logger.debug(s"Attempting a bulkPut with ${putRequests.length} items") val resultFuture = kvStore.multiPut(putRequests) resultFuture.map { responses => if (responses.contains(false)) { From dfc4e1fdd623f0465c82a952981bd1c05352a177 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:15:28 -0500 Subject: [PATCH 29/37] style: Apply scalafix and scalafmt changes --- .../scripts/ObservabilityDemoDataLoader.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala index 65275b8d94..f317488273 100644 --- a/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala +++ b/spark/src/main/scala/ai/chronon/spark/scripts/ObservabilityDemoDataLoader.scala @@ -3,13 +3,19 @@ package ai.chronon.spark.scripts import ai.chronon.api.ColorPrinter.ColorString import ai.chronon.api.Constants import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.online.{HTTPKVStore, KVStore} -import ai.chronon.spark.{SparkSessionBuilder, TableUtils} -import ai.chronon.spark.stats.drift.{Summarizer, SummaryUploader} +import ai.chronon.online.HTTPKVStore +import ai.chronon.online.KVStore +import ai.chronon.spark.SparkSessionBuilder +import ai.chronon.spark.TableUtils +import ai.chronon.spark.stats.drift.Summarizer +import ai.chronon.spark.stats.drift.SummaryUploader import ai.chronon.spark.stats.drift.scripts.PrepareData -import ai.chronon.spark.utils.{InMemoryKvStore, MockApi} -import org.rogach.scallop.{ScallopConf, ScallopOption} -import org.slf4j.{Logger, LoggerFactory} +import ai.chronon.spark.utils.InMemoryKvStore +import ai.chronon.spark.utils.MockApi +import org.rogach.scallop.ScallopConf +import org.rogach.scallop.ScallopOption +import org.slf4j.Logger +import org.slf4j.LoggerFactory object ObservabilityDemoDataLoader { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) From 6ab315c8f1f819cce36456eb12510b9a136d760e Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 12:43:56 -0500 Subject: [PATCH 30/37] Handle empty responses --- hub/app/controllers/TimeSeriesController.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala index ca5207c688..cae7ca25c9 100644 --- a/hub/app/controllers/TimeSeriesController.scala +++ b/hub/app/controllers/TimeSeriesController.scala @@ 
-188,8 +188,14 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon Future.sequence(Seq(currentSummarySeriesFuture, baselineSummarySeriesFuture)).map { merged => val currentSummarySeries = merged.head val baselineSummarySeries = merged.last - val currentFeatureTs = convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) - val baselineFeatureTs = convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + val currentFeatureTs = { + if (currentSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(currentSummarySeries.head, metric) + } + val baselineFeatureTs = { + if (baselineSummarySeries.isEmpty) Seq.empty + else convertTileSummarySeriesToTimeSeries(baselineSummarySeries.head, metric) + } val comparedTsData = ComparedFeatureTimeSeries(name, baselineFeatureTs, currentFeatureTs) Ok(comparedTsData.asJson.noSpaces) } From 62687b97a83c268fd3554fd1978035b42edf975c Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 15:15:23 -0500 Subject: [PATCH 31/37] Remove redundant log4j props file --- docker-init/demo/log4j2.properties | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 docker-init/demo/log4j2.properties diff --git a/docker-init/demo/log4j2.properties b/docker-init/demo/log4j2.properties deleted file mode 100644 index a0167384ee..0000000000 --- a/docker-init/demo/log4j2.properties +++ /dev/null @@ -1,17 +0,0 @@ -# Root logger -rootLogger.level = ERROR -rootLogger.appenderRef.console.ref = console - -# Console appender configuration -appender.console.type = Console -appender.console.name = console -appender.console.target = SYSTEM_OUT -appender.console.layout.type = PatternLayout -appender.console.layout.pattern = %yellow{%d{yyyy/MM/dd HH:mm:ss}} %highlight{%-5level} %green{%file:%line} - %message%n - -# Configure specific logger -logger.chronon.name = ai.chronon -logger.chronon.level = info - -# Configure colors -appender.console.layout.disableAnsi = false \ No newline at end of file From 5316680a3ec373dd1a34eca8b675875ae009dfd1 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 15:16:00 -0500 Subject: [PATCH 32/37] Use thread locals for thrift serializers --- .../ai/chronon/online/stats/DriftStore.scala | 20 ++++++++++++------- .../spark/stats/drift/Summarizer.scala | 5 +++-- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala index 733ebbfc6d..b1b935a04f 100644 --- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala +++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala @@ -12,8 +12,8 @@ import ai.chronon.api.thrift.protocol.TProtocolFactory import ai.chronon.online.KVStore import ai.chronon.online.KVStore.GetRequest import ai.chronon.online.MetadataStore -import ai.chronon.online.stats.DriftStore.compactDeserializer -import ai.chronon.online.stats.DriftStore.compactSerializer +import ai.chronon.online.stats.DriftStore.binaryDeserializer +import ai.chronon.online.stats.DriftStore.binarySerializer import java.io.Serializable import scala.concurrent.Future @@ -52,8 +52,6 @@ class DriftStore(kvStore: KVStore, } } - private val deserializer: TDeserializer = compactDeserializer - private case class SummaryRequestContext(request: GetRequest, tileKey: TileKey, groupName: String) private case class SummaryResponseContext(summaries: Array[(TileSummary, Long)], tileKey: TileKey, groupName: String) 
@@ -76,7 +74,7 @@ class DriftStore(kvStore: KVStore, endMs: Option[Long], columnPrefix: Option[String]): Future[Seq[TileSummaryInfo]] = { - val serializer: TSerializer = compactSerializer + val serializer: TSerializer = binarySerializer.get() val tileKeyMap = tileKeysForJoin(joinConf, None, columnPrefix) val requestContextMap: Map[GetRequest, SummaryRequestContext] = tileKeyMap.flatMap { case (group, keys) => @@ -90,6 +88,7 @@ class DriftStore(kvStore: KVStore, val responseFuture = kvStore.multiGet(requestContextMap.keys.toSeq) responseFuture.map { responses => + val deserializer = binaryDeserializer.get() // deserialize the responses and surround with context val responseContextTries: Seq[Try[SummaryResponseContext]] = responses.map { response => val valuesTry = response.values @@ -200,9 +199,16 @@ object DriftStore { class SerializableSerializer(factory: TProtocolFactory) extends TSerializer(factory) with Serializable // crazy bug in compact protocol - do not change to compact - def compactSerializer: SerializableSerializer = new SerializableSerializer(new TBinaryProtocol.Factory()) - def compactDeserializer: TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + @transient + lazy val binarySerializer: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] { + override def initialValue(): TSerializer = new TSerializer(new TBinaryProtocol.Factory()) + } + + @transient + lazy val binaryDeserializer: ThreadLocal[TDeserializer] = new ThreadLocal[TDeserializer] { + override def initialValue(): TDeserializer = new TDeserializer(new TBinaryProtocol.Factory()) + } // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries val percentileLabels: Seq[String] = Seq("p0", diff --git a/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala b/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala index 856850b892..2874f3b907 100644 --- a/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala +++ b/spark/src/main/scala/ai/chronon/spark/stats/drift/Summarizer.scala @@ -6,7 +6,7 @@ import ai.chronon.api._ import ai.chronon.online.Api import ai.chronon.online.KVStore.GetRequest import ai.chronon.online.KVStore.PutRequest -import ai.chronon.online.stats.DriftStore.compactSerializer +import ai.chronon.online.stats.DriftStore.binarySerializer import ai.chronon.spark.TableUtils import ai.chronon.spark.stats.drift.Expressions.CardinalityExpression import ai.chronon.spark.stats.drift.Expressions.SummaryExpression @@ -322,9 +322,10 @@ class SummaryPacker(confPath: String, val func: sql.Row => Seq[TileRow] = Expressions.summaryPopulatorFunc(summaryExpressions, df.schema, keyBuilder, tu.partitionColumn) - val serializer = compactSerializer val packedRdd: RDD[sql.Row] = df.rdd.flatMap(func).map { tileRow => // pack into bytes + val serializer = binarySerializer.get() + val partition = tileRow.partition val timestamp = tileRow.tileTs val summaries = tileRow.summaries From 92f0d11aa67c3c946a3e376ef9a9011892392bf9 Mon Sep 17 00:00:00 2001 From: Piyush Narang Date: Wed, 27 Nov 2024 16:01:35 -0500 Subject: [PATCH 33/37] Rebase + comments --- build.sbt | 1 + .../main/scala/ai/chronon/spark/scripts/DataServer.scala | 7 +------ 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/build.sbt b/build.sbt index fd4852dfb4..c80ede6a12 100644 --- a/build.sbt +++ b/build.sbt @@ -80,6 +80,7 @@ val jackson = Seq( "com.fasterxml.jackson.module" %% "jackson-module-scala" ).map(_ % jackson_2_15) +// Circe is used to ser / deser case class 
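For context on the patch above: Thrift's TSerializer keeps mutable internal buffers and is not thread-safe, so sharing a single instance across Spark tasks or request threads risks corrupt output. Below is a small, self-contained sketch of the thread-local pattern the patch adopts; it uses the stock org.apache.thrift coordinates purely for illustration, while the project's own classes live under ai.chronon.api.thrift.

import org.apache.thrift.TBase
import org.apache.thrift.TSerializer
import org.apache.thrift.protocol.TBinaryProtocol

object ThreadLocalSerializerSketch {
  // one serializer per thread, created lazily the first time that thread calls get()
  private val binary: ThreadLocal[TSerializer] = new ThreadLocal[TSerializer] {
    override def initialValue(): TSerializer = new TSerializer(new TBinaryProtocol.Factory())
  }

  // safe to call concurrently: every thread serializes with its own instance
  def toBytes(struct: TBase[_, _]): Array[Byte] = binary.get().serialize(struct)
}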
From 62687b97a83c268fd3554fd1978035b42edf975c Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 27 Nov 2024 16:01:35 -0500
Subject: [PATCH 33/37] Rebase + comments

---
 build.sbt                                                 | 1 +
 .../main/scala/ai/chronon/spark/scripts/DataServer.scala  | 7 +------
 2 files changed, 2 insertions(+), 6 deletions(-)

diff --git a/build.sbt b/build.sbt
index fd4852dfb4..c80ede6a12 100644
--- a/build.sbt
+++ b/build.sbt
@@ -80,6 +80,7 @@ val jackson = Seq(
   "com.fasterxml.jackson.module" %% "jackson-module-scala"
 ).map(_ % jackson_2_15)
 
+// Circe is used to ser / deser case class payloads for the Hub Play webservice
 val circe = Seq(
   "io.circe" %% "circe-core",
   "io.circe" %% "circe-generic",
diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
index fba36f4d0c..cf935fd334 100644
--- a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
+++ b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
@@ -35,15 +35,10 @@ class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSumma
     ctx.flush()
   }
 
-  private val serializer: ThreadLocal[SerializableSerializer] =
-    ThreadLocal.withInitial(new Supplier[SerializableSerializer] {
-      override def get(): SerializableSerializer = DriftStore.compactSerializer
-    })
-
   private def convertToBytesMap[T <: TBase[_, _]: Manifest: ClassTag](
       series: T,
       keyF: T => TileSeriesKey): Map[String, String] = {
-    val serializerInstance = serializer.get()
+    val serializerInstance = DriftStore.binarySerializer.get()
     val encoder = Base64.getEncoder
     val keyBytes = serializerInstance.serialize(keyF(series))
     val valueBytes = serializerInstance.serialize(series)
From 4965c8a292407b7f901dd0b8c25647e0337aacf8 Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 27 Nov 2024 16:02:11 -0500
Subject: [PATCH 34/37] style: Apply scalafix and scalafmt changes

---
 spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala | 2 --
 1 file changed, 2 deletions(-)

diff --git a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
index cf935fd334..afd194a7d2 100644
--- a/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
+++ b/spark/src/main/scala/ai/chronon/spark/scripts/DataServer.scala
@@ -5,7 +5,6 @@ import ai.chronon.api.TileSeriesKey
 import ai.chronon.api.TileSummarySeries
 import ai.chronon.api.thrift.TBase
 import ai.chronon.online.stats.DriftStore
-import ai.chronon.online.stats.DriftStore.SerializableSerializer
 import com.fasterxml.jackson.databind.ObjectMapper
 import com.fasterxml.jackson.databind.SerializationFeature
 import com.fasterxml.jackson.module.scala.DefaultScalaModule
@@ -19,7 +18,6 @@ import io.netty.handler.codec.http._
 import io.netty.util.CharsetUtil
 
 import java.util.Base64
-import java.util.function.Supplier
 import scala.reflect.ClassTag
 
 class DataServer(driftSeries: Seq[TileDriftSeries], summarySeries: Seq[TileSummarySeries], port: Int = 8181) {
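A side note on the Circe comment introduced in build.sbt above: the Hub controllers encode case classes with .asJson.noSpaces and parse request bodies with decode[...]. The round trip below is a minimal sketch of that usage; the Payload case class is made up for illustration and is not one of the Hub's real types.

import io.circe.generic.auto._
import io.circe.parser.decode
import io.circe.syntax._

object CirceRoundTripSketch {
  case class Payload(name: String, values: Seq[Double]) // stand-in for a real Hub response type

  val json: String = Payload("p95", Seq(0.1, 0.2)).asJson.noSpaces // case class to compact JSON
  val back: Either[io.circe.Error, Payload] = decode[Payload](json) // JSON back to the case class
}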
From 06ccac71f7ebede1e82bd497b9fcf4e0b1a14d76 Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 27 Nov 2024 17:49:51 -0500
Subject: [PATCH 35/37] Add breaks method

---
 .../controllers/TimeSeriesController.scala    |  2 +-
 .../TimeSeriesControllerSpec.scala            |  4 ++--
 .../ai/chronon/online/stats/DriftStore.scala  | 22 +------------------
 3 files changed, 4 insertions(+), 24 deletions(-)

diff --git a/hub/app/controllers/TimeSeriesController.scala b/hub/app/controllers/TimeSeriesController.scala
index cae7ca25c9..68c0d43c5f 100644
--- a/hub/app/controllers/TimeSeriesController.scala
+++ b/hub/app/controllers/TimeSeriesController.scala
@@ -238,7 +238,7 @@ class TimeSeriesController @Inject() (val controllerComponents: ControllerCompon
       if (isNumeric) {
         summarySeries.percentiles.asScala.zip(summarySeries.timestamps.asScala).flatMap {
           case (percentiles, ts) =>
-            DriftStore.percentileLabels.zip(percentiles.asScala).map {
+            DriftStore.breaks(20).zip(percentiles.asScala).map {
               case (l, value) => TimeSeriesPoint(value, ts, Some(l))
             }
         }
diff --git a/hub/test/controllers/TimeSeriesControllerSpec.scala b/hub/test/controllers/TimeSeriesControllerSpec.scala
index c32d87226e..ce47d2811b 100644
--- a/hub/test/controllers/TimeSeriesControllerSpec.scala
+++ b/hub/test/controllers/TimeSeriesControllerSpec.scala
@@ -209,7 +209,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues {
     }
 
     // expect one entry per percentile for each time series point
-    val expectedLength = DriftStore.percentileLabels.length * expectedHours(startTs, endTs)
+    val expectedLength = DriftStore.breaks(20).length * expectedHours(startTs, endTs)
     response.current.length mustBe expectedLength
     response.baseline.length mustBe expectedLength
   }
@@ -299,7 +299,7 @@ class TimeSeriesControllerSpec extends PlaySpec with Results with EitherValues {
 
     if (isNumeric) {
       val percentileList = timestamps.map { _ =>
-        List.fill(DriftStore.percentileLabels.length)(JDouble.valueOf(0.12)).asJava
+        List.fill(DriftStore.breaks(20).length)(JDouble.valueOf(0.12)).asJava
       }.asJava
       tileSummarySeries.setPercentiles(percentileList)
     } else {
diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
index b1b935a04f..84431aa10a 100644
--- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
+++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
@@ -211,25 +211,5 @@ object DriftStore {
   }
 
   // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries
-  val percentileLabels: Seq[String] = Seq("p0",
-                                          "p5",
-                                          "p10",
-                                          "p15",
-                                          "p20",
-                                          "p25",
-                                          "p30",
-                                          "p35",
-                                          "p40",
-                                          "p45",
-                                          "p50",
-                                          "p55",
-                                          "p60",
-                                          "p65",
-                                          "p70",
-                                          "p75",
-                                          "p80",
-                                          "p85",
-                                          "p90",
-                                          "p95",
-                                          "p100")
+  def breaks(count: Int): Seq[String] = (0 to count).map(_ * (100/count)).map("p" + _.toString)
 }
From 3089feb356ad3787ae684db26bdd5cfd67a8c187 Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 27 Nov 2024 17:50:40 -0500
Subject: [PATCH 36/37] style: Apply scalafix and scalafmt changes

---
 online/src/main/scala/ai/chronon/online/stats/DriftStore.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
index 84431aa10a..f9e00aa587 100644
--- a/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
+++ b/online/src/main/scala/ai/chronon/online/stats/DriftStore.scala
@@ -211,5 +211,5 @@ object DriftStore {
   }
 
   // todo - drop this hard-coded list in favor of a well known list or exposing as part of summaries
-  def breaks(count: Int): Seq[String] = (0 to count).map(_ * (100/count)).map("p" + _.toString)
+  def breaks(count: Int): Seq[String] = (0 to count).map(_ * (100 / count)).map("p" + _.toString)
 }
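A quick worked example of the breaks helper added above, plus one caveat that is the editor's observation rather than part of the patch: the step size uses integer division, so counts that do not divide 100 evenly stop short of p100.

object BreaksSketch {
  // same shape as DriftStore.breaks: count + 1 evenly spaced percentile labels
  def breaks(count: Int): Seq[String] = (0 to count).map(_ * (100 / count)).map("p" + _.toString)

  val twenty: Seq[String] = breaks(20) // Seq("p0", "p5", "p10", ..., "p95", "p100"), 21 labels
  val four: Seq[String] = breaks(4)    // Seq("p0", "p25", "p50", "p75", "p100")
  val three: Seq[String] = breaks(3)   // Seq("p0", "p33", "p66", "p99"), integer division stops short of p100
}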
From 4930a984a1f65e760d2f8a2c791dd4b6272e267a Mon Sep 17 00:00:00 2001
From: Piyush Narang
Date: Wed, 27 Nov 2024 18:09:48 -0500
Subject: [PATCH 37/37] Wrap team in optional

---
 hub/app/model/Model.scala                       | 2 +-
 hub/app/store/MonitoringModelStore.scala        | 4 ++--
 hub/test/controllers/ModelControllerSpec.scala  | 2 +-
 hub/test/controllers/SearchControllerSpec.scala | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/hub/app/model/Model.scala b/hub/app/model/Model.scala
index a83c1c1803..39aa971892 100644
--- a/hub/app/model/Model.scala
+++ b/hub/app/model/Model.scala
@@ -7,7 +7,7 @@ case class Join(name: String,
                 groupBys: Seq[GroupBy],
                 online: Boolean,
                 production: Boolean,
-                team: String)
+                team: Option[String])
 
 case class Model(name: String, join: Join, online: Boolean, production: Boolean, team: String, modelType: String)
 
 /** Supported Metric types */
diff --git a/hub/app/store/MonitoringModelStore.scala b/hub/app/store/MonitoringModelStore.scala
index 89fbc5fa3e..30b62e0937 100644
--- a/hub/app/store/MonitoringModelStore.scala
+++ b/hub/app/store/MonitoringModelStore.scala
@@ -65,7 +65,7 @@ class MonitoringModelStore(apiImpl: Api) {
                         groupBys,
                         thriftJoin.metaData.online,
                         thriftJoin.metaData.production,
-                        thriftJoin.metaData.team)
+                        Option(thriftJoin.metaData.team))
       Option(
         Model(m.metaData.name, join, m.metaData.online, m.metaData.production, m.metaData.team, m.modelType.name()))
     } else {
@@ -86,7 +86,7 @@ class MonitoringModelStore(apiImpl: Api) {
                  groupBys,
                  thriftJoin.metaData.online,
                  thriftJoin.metaData.production,
-                 thriftJoin.metaData.team)
+                 Option(thriftJoin.metaData.team))
     }
   }
 
diff --git a/hub/test/controllers/ModelControllerSpec.scala b/hub/test/controllers/ModelControllerSpec.scala
index d68b536e39..2ff2b4ee7f 100644
--- a/hub/test/controllers/ModelControllerSpec.scala
+++ b/hub/test/controllers/ModelControllerSpec.scala
@@ -71,7 +71,7 @@ class ModelControllerSpec extends PlaySpec with Results with EitherValues {
 object MockDataService {
   def generateMockModel(id: String): Model = {
     val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2")))
-    val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team")
+    val join = Join("my_join", Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, Some("my_team"))
     Model(id, join, online = true, production = true, "my team", "XGBoost")
   }
 
diff --git a/hub/test/controllers/SearchControllerSpec.scala b/hub/test/controllers/SearchControllerSpec.scala
index 8ecca5c5c7..07dcbbdb7d 100644
--- a/hub/test/controllers/SearchControllerSpec.scala
+++ b/hub/test/controllers/SearchControllerSpec.scala
@@ -73,7 +73,7 @@ class SearchControllerSpec extends PlaySpec with Results with EitherValues {
 object MockJoinService {
   def generateMockJoin(id: String): Join = {
     val groupBys = Seq(GroupBy("my_groupBy", Seq("g1", "g2")))
-    Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, "my_team")
+    Join(id, Seq("ext_f1", "ext_f2", "d_1", "d2"), groupBys, true, true, Some("my_team"))
   }
 
   val mockJoinRegistry: Seq[Join] = (0 until 100).map(i => generateMockJoin(i.toString))