52 changes: 28 additions & 24 deletions online/src/main/scala/ai/chronon/online/CatalystUtil.scala
@@ -41,6 +41,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.function
import scala.collection.Seq
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

object CatalystUtil {
private class IteratorWrapper[T] extends Iterator[T] {
@@ -109,7 +110,7 @@ class PoolMap[Key, Value](createFunc: Key => Value, maxSize: Int = 100, initialS
class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSchema: StructType) {
private val poolKey = PoolKey(expressions, inputSchema)
private val cuPool = poolMap.getPool(PoolKey(expressions, inputSchema))
def performSql(values: Map[String, Any]): Option[Map[String, Any]] =
def performSql(values: Map[String, Any]): Seq[Map[String, Any]] =
poolMap.performWithValue(poolKey, cuPool) { _.performSql(values) }
def outputChrononSchema: Array[(String, DataType)] =
poolMap.performWithValue(poolKey, cuPool) { _.outputChrononSchema }
@@ -136,34 +137,34 @@ class CatalystUtil(inputSchema: StructType,
private val inputEncoder = SparkInternalRowConversions.to(inputSparkSchema)
private val inputArrEncoder = SparkInternalRowConversions.to(inputSparkSchema, false)

private val (transformFunc: (InternalRow => Option[InternalRow]), outputSparkSchema: types.StructType) = initialize()
private val (transformFunc: (InternalRow => Seq[InternalRow]), outputSparkSchema: types.StructType) = initialize()

private lazy val outputArrDecoder = SparkInternalRowConversions.from(outputSparkSchema, false)
@transient lazy val outputChrononSchema: Array[(String, DataType)] =
SparkConversions.toChrononSchema(outputSparkSchema)
private val outputDecoder = SparkInternalRowConversions.from(outputSparkSchema)

def performSql(values: Array[Any]): Option[Array[Any]] = {
def performSql(values: Array[Any]): Seq[Array[Any]] = {
val internalRow = inputArrEncoder(values).asInstanceOf[InternalRow]
val resultRowOpt = transformFunc(internalRow)
val outputVal = resultRowOpt.map(resultRow => outputArrDecoder(resultRow))
val resultRowSeq = transformFunc(internalRow)
val outputVal = resultRowSeq.map(resultRow => outputArrDecoder(resultRow))
outputVal.map(_.asInstanceOf[Array[Any]])
}

def performSql(values: Map[String, Any]): Option[Map[String, Any]] = {
def performSql(values: Map[String, Any]): Seq[Map[String, Any]] = {
val internalRow = inputEncoder(values).asInstanceOf[InternalRow]
performSql(internalRow)
}

def performSql(row: InternalRow): Option[Map[String, Any]] = {
def performSql(row: InternalRow): Seq[Map[String, Any]] = {
val resultRowMaybe = transformFunc(row)
val outputVal = resultRowMaybe.map(resultRow => outputDecoder(resultRow))
outputVal.map(_.asInstanceOf[Map[String, Any]])
}

def getOutputSparkSchema: types.StructType = outputSparkSchema

private def initialize(): (InternalRow => Option[InternalRow], types.StructType) = {
private def initialize(): (InternalRow => Seq[InternalRow], types.StructType) = {
val session = CatalystUtil.session

// run through and execute the setup statements
@@ -189,32 +190,34 @@
val filteredDf = whereClauseOpt.map(df.where(_)).getOrElse(df)

// extract transform function from the df spark plan
val func: InternalRow => Option[InternalRow] = filteredDf.queryExecution.executedPlan match {
val func: InternalRow => ArrayBuffer[InternalRow] = filteredDf.queryExecution.executedPlan match {
case whc: WholeStageCodegenExec => {
val (ctx, cleanedSource) = whc.doCodeGen()
val (clazz, _) = CodeGenerator.compile(cleanedSource)
val references = ctx.references.toArray
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
val iteratorWrapper: IteratorWrapper[InternalRow] = new IteratorWrapper[InternalRow]
buffer.init(0, Array(iteratorWrapper))
def codegenFunc(row: InternalRow): Option[InternalRow] = {
def codegenFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
iteratorWrapper.put(row)
val result = ArrayBuffer.empty[InternalRow]
while (buffer.hasNext) {
return Some(buffer.next())
result.append(buffer.next())
}
None
result
}
codegenFunc
}

case ProjectExec(projectList, fp @ FilterExec(condition, child)) => {
val unsafeProjection = UnsafeProjection.create(projectList, fp.output)

def projectFunc(row: InternalRow): Option[InternalRow] = {
def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
val r = CatalystHelper.evalFilterExec(row, condition, child.output)
if (r)
Some(unsafeProjection.apply(row))
ArrayBuffer(unsafeProjection.apply(row))
else
None
ArrayBuffer.empty[InternalRow]
}

projectFunc
@@ -230,33 +233,34 @@
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
val iteratorWrapper: IteratorWrapper[InternalRow] = new IteratorWrapper[InternalRow]
buffer.init(0, Array(iteratorWrapper))
def codegenFunc(row: InternalRow): Option[InternalRow] = {
def codegenFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
iteratorWrapper.put(row)
val result = ArrayBuffer.empty[InternalRow]
while (buffer.hasNext) {
return Some(unsafeProjection.apply(buffer.next()))
result.append(unsafeProjection.apply(buffer.next()))
}
None
result
}
codegenFunc
case _ =>
val unsafeProjection = UnsafeProjection.create(projectList, childPlan.output)
def projectFunc(row: InternalRow): Option[InternalRow] = {
Some(unsafeProjection.apply(row))
def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
ArrayBuffer(unsafeProjection.apply(row))
}
projectFunc
}
}
case ltse: LocalTableScanExec => {
// Input `row` is unused because for LTSE, no input is needed to compute the output
def projectFunc(row: InternalRow): Option[InternalRow] =
ltse.executeCollect().headOption
def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] =
ArrayBuffer(ltse.executeCollect(): _*)

projectFunc
}
case rddse: RDDScanExec => {
val unsafeProjection = UnsafeProjection.create(rddse.schema)
def projectFunc(row: InternalRow): Option[InternalRow] =
Some(unsafeProjection.apply(row))
def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] =
ArrayBuffer(unsafeProjection.apply(row))

projectFunc
}
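Note on the change above: every `performSql` overload now returns a `Seq`, where an empty result means the input row was filtered out rather than `None`, and the codegen loop can now surface more than one output row per input. A minimal sketch of a migrated call site (`cu` is any already-constructed `CatalystUtil`; the printlns are purely illustrative):

```scala
import ai.chronon.online.CatalystUtil

// Sketch of a caller adapted to the new Seq-based result.
def handle(cu: CatalystUtil, input: Map[String, Any]): Unit =
  cu.performSql(input) match {
    case Seq()       => println("row rejected by the where clause")   // previously None
    case Seq(single) => println(s"single derived row: $single")       // previously Some(row)
    case many        => println(s"${many.size} derived rows")         // now possible per input
  }
```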
@@ -45,7 +45,7 @@ object OnlineDerivationUtil {
): DerivationFunc = {
{
case (keys: Map[String, Any], values: Map[String, Any]) =>
reintroduceExceptions(catalystUtil.performSql(keys ++ values).orNull, values)
reintroduceExceptions(catalystUtil.performSql(keys ++ values).headOption.orNull, values)
}
}

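Since the old code already handed a possibly-null map to `reintroduceExceptions`, the adaptation above simply takes the head of the new `Seq`. A hedged restatement as a standalone helper (names are illustrative, not part of the change):

```scala
import ai.chronon.online.CatalystUtil

// headOption restores the previous at-most-one-row contract; orNull keeps the
// null value that the old Option-based call site already produced.
def deriveOrNull(cu: CatalystUtil,
                 keys: Map[String, Any],
                 values: Map[String, Any]): Map[String, Any] =
  cu.performSql(keys ++ values).headOption.orNull
```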
16 changes: 14 additions & 2 deletions online/src/main/scala/ai/chronon/online/stats/PivotUtils.scala
@@ -114,10 +114,22 @@ object PivotUtils {

private def collectDoubles(vals: Iterator[JDouble]): JList[JDouble] = {
Contributor (author) comment: another PR broke this - not sure how that went in, so fixing it here to get the CI passing on this one.

val result = new JArrayList[JDouble]()

var sawValidInput = false

while (vals.hasNext) {
result.add(vals.next())

val next = vals.next()

// a value is valid if it is neither null nor the magic-null sentinel; remember whether any valid value has appeared
val thisIsValid = next != Constants.magicNullDouble && next != null
sawValidInput = sawValidInput || thisIsValid

result.add(next)
}
result

// if no valid input, return null
if (!sawValidInput) null else result
}

def pivot(driftsWithTimestamps: Array[(TileDrift, Long)]): TileDriftSeries = {
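To make the new `collectDoubles` contract easy to check, here is a self-contained restatement with an assumed sentinel standing in for `Constants.magicNullDouble` (the real constant lives in Chronon and is not reproduced here): an all-invalid column now pivots to null, while mixed input keeps every element in place.

```scala
import java.lang.{Double => JDouble}
import java.util.{ArrayList => JArrayList, List => JList}

object CollectDoublesSketch {
  // Stand-in for Constants.magicNullDouble, used only for illustration.
  private val magicNullDouble: JDouble = JDouble.valueOf(Double.MinValue)

  def collectDoubles(vals: Iterator[JDouble]): JList[JDouble] = {
    val result = new JArrayList[JDouble]()
    var sawValidInput = false
    while (vals.hasNext) {
      val next = vals.next()
      // valid means neither null nor the magic-null sentinel
      sawValidInput = sawValidInput || (next != magicNullDouble && next != null)
      result.add(next)
    }
    if (!sawValidInput) null else result // all-invalid input collapses to null
  }

  def main(args: Array[String]): Unit = {
    // an all-null/sentinel column now pivots to null instead of a list of sentinels
    assert(collectDoubles(Iterator(null, magicNullDouble)) == null)
    // mixed input keeps every element, nulls included, in order
    assert(collectDoubles(Iterator(JDouble.valueOf(1.5), null)).size() == 2)
  }
}
```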
@@ -6,7 +6,7 @@ import org.scalatest.flatspec.AnyFlatSpec

class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLStructs with TaggedFilterSuite {

it should "hive ud fs via setups should work" in {
"catalyst util" should "work with hive_udfs via setups should work" in {
val setups = Seq(
"CREATE FUNCTION MINUS_ONE AS 'ai.chronon.online.test.Minus_One'",
"CREATE FUNCTION CAT_STR AS 'ai.chronon.online.test.Cat_Str'",
@@ -16,10 +16,12 @@ class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLS
"b" -> "CAT_STR(string_x)"
)
val cu = new CatalystUtil(CommonScalarsStruct, selects = selects, setups = setups)
val res = cu.performSql(CommonScalarsRow)
assertEquals(res.get.size, 2)
assertEquals(res.get("a"), Int.MaxValue - 1)
assertEquals(res.get("b"), "hello123")
val resList = cu.performSql(CommonScalarsRow)
assertEquals(resList.size, 1)
val resMap = resList.head
assertEquals(resMap.size, 2)
assertEquals(resMap("a"), Int.MaxValue - 1)
assertEquals(resMap("b"), "hello123")
}

override def tagName: String = "catalystUtilHiveUdfTest"