online/src/main/scala/ai/chronon/online/CatalystUtil.scala (52 changes: 28 additions & 24 deletions)

@@ -41,6 +41,7 @@ import java.util.concurrent.ConcurrentHashMap
 import java.util.function
 import scala.collection.Seq
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer

 object CatalystUtil {
   private class IteratorWrapper[T] extends Iterator[T] {
@@ -109,7 +110,7 @@ class PoolMap[Key, Value](createFunc: Key => Value, maxSize: Int = 100, initialS
 class PooledCatalystUtil(expressions: collection.Seq[(String, String)], inputSchema: StructType) {
   private val poolKey = PoolKey(expressions, inputSchema)
   private val cuPool = poolMap.getPool(PoolKey(expressions, inputSchema))
-  def performSql(values: Map[String, Any]): Option[Map[String, Any]] =
+  def performSql(values: Map[String, Any]): Seq[Map[String, Any]] =
     poolMap.performWithValue(poolKey, cuPool) { _.performSql(values) }
   def outputChrononSchema: Array[(String, DataType)] =
     poolMap.performWithValue(poolKey, cuPool) { _.outputChrononSchema }
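
PooledCatalystUtil only delegates, so its signature change is mechanical; the real migration burden falls on callers that pattern-matched on the old Option result. A minimal sketch of the call-site change (the emitAll helper is hypothetical; PooledCatalystUtil is the class above):

    import scala.collection.Seq

    // Hypothetical helper showing the Option -> Seq call-site migration.
    def emitAll(pooled: PooledCatalystUtil,
                input: Map[String, Any])(emit: Map[String, Any] => Unit): Unit =
      pooled.performSql(input) match {
        case Seq() => ()                 // row filtered out; this used to be None
        case rows  => rows.foreach(emit) // previously at most one row, Some(row)
      }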
@@ -136,34 +137,34 @@ class CatalystUtil(inputSchema: StructType,
   private val inputEncoder = SparkInternalRowConversions.to(inputSparkSchema)
   private val inputArrEncoder = SparkInternalRowConversions.to(inputSparkSchema, false)

-  private val (transformFunc: (InternalRow => Option[InternalRow]), outputSparkSchema: types.StructType) = initialize()
+  private val (transformFunc: (InternalRow => Seq[InternalRow]), outputSparkSchema: types.StructType) = initialize()

   private lazy val outputArrDecoder = SparkInternalRowConversions.from(outputSparkSchema, false)
   @transient lazy val outputChrononSchema: Array[(String, DataType)] =
     SparkConversions.toChrononSchema(outputSparkSchema)
   private val outputDecoder = SparkInternalRowConversions.from(outputSparkSchema)

-  def performSql(values: Array[Any]): Option[Array[Any]] = {
+  def performSql(values: Array[Any]): Seq[Array[Any]] = {
     val internalRow = inputArrEncoder(values).asInstanceOf[InternalRow]
-    val resultRowOpt = transformFunc(internalRow)
-    val outputVal = resultRowOpt.map(resultRow => outputArrDecoder(resultRow))
+    val resultRowSeq = transformFunc(internalRow)
+    val outputVal = resultRowSeq.map(resultRow => outputArrDecoder(resultRow))
     outputVal.map(_.asInstanceOf[Array[Any]])
   }

-  def performSql(values: Map[String, Any]): Option[Map[String, Any]] = {
+  def performSql(values: Map[String, Any]): Seq[Map[String, Any]] = {
     val internalRow = inputEncoder(values).asInstanceOf[InternalRow]
     performSql(internalRow)
   }

-  def performSql(row: InternalRow): Option[Map[String, Any]] = {
+  def performSql(row: InternalRow): Seq[Map[String, Any]] = {
     val resultRowMaybe = transformFunc(row)
     val outputVal = resultRowMaybe.map(resultRow => outputDecoder(resultRow))
     outputVal.map(_.asInstanceOf[Map[String, Any]])
   }

   def getOutputSparkSchema: types.StructType = outputSparkSchema

-  private def initialize(): (InternalRow => Option[InternalRow], types.StructType) = {
+  private def initialize(): (InternalRow => Seq[InternalRow], types.StructType) = {
     val session = CatalystUtil.session

     // run through and execute the setup statements
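
The three performSql bodies survive nearly verbatim because map lines up on Option and Seq; only the cardinality at the edges changes. The new contract in plain Scala: an empty Seq where the old code said None, a singleton where it said Some, and genuine fan-out now representable:

    val asOption: Option[Int] = Some(2)          // old world: zero or one row
    val asSeq: Seq[Int] = Seq(2, 3)              // new world: fan-out is representable

    assert(asOption.map(_ * 10) == Some(20))
    assert(asSeq.map(_ * 10) == Seq(20, 30))     // every row transformed, none dropped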
@@ -189,32 +190,34 @@
     val filteredDf = whereClauseOpt.map(df.where(_)).getOrElse(df)

     // extract transform function from the df spark plan
-    val func: InternalRow => Option[InternalRow] = filteredDf.queryExecution.executedPlan match {
+    val func: InternalRow => ArrayBuffer[InternalRow] = filteredDf.queryExecution.executedPlan match {
       case whc: WholeStageCodegenExec => {
         val (ctx, cleanedSource) = whc.doCodeGen()
         val (clazz, _) = CodeGenerator.compile(cleanedSource)
         val references = ctx.references.toArray
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
         val iteratorWrapper: IteratorWrapper[InternalRow] = new IteratorWrapper[InternalRow]
         buffer.init(0, Array(iteratorWrapper))
-        def codegenFunc(row: InternalRow): Option[InternalRow] = {
+        def codegenFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
           iteratorWrapper.put(row)
+          val result = ArrayBuffer.empty[InternalRow]
           while (buffer.hasNext) {
-            return Some(buffer.next())
+            result.append(buffer.next())
           }
-          None
+          result
         }
         codegenFunc
       }

       case ProjectExec(projectList, fp @ FilterExec(condition, child)) => {
         val unsafeProjection = UnsafeProjection.create(projectList, fp.output)

-        def projectFunc(row: InternalRow): Option[InternalRow] = {
+        def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
           val r = CatalystHelper.evalFilterExec(row, condition, child.output)
           if (r)
-            Some(unsafeProjection.apply(row))
+            ArrayBuffer(unsafeProjection.apply(row))
           else
-            None
+            ArrayBuffer.empty[InternalRow]
         }

         projectFunc
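
Two distinct fixes share this hunk. In the codegen branch, the old `return Some(buffer.next())` exited on the first row the generated iterator produced, silently dropping any further rows and leaving them buffered; the ArrayBuffer version drains the iterator. In the interpreted filter branch, "predicate failed" becomes an empty buffer instead of None. A pure-Scala analogue of the filter shape (not the Spark code itself):

    import scala.collection.mutable.ArrayBuffer

    def filterThenProject[A, B](row: A)(pred: A => Boolean)(project: A => B): ArrayBuffer[B] =
      if (pred(row)) ArrayBuffer(project(row)) else ArrayBuffer.empty[B]

    assert(filterThenProject(3)(_ % 2 == 1)(_ * 10) == ArrayBuffer(30))  // kept
    assert(filterThenProject(4)(_ % 2 == 1)(_ * 10) == ArrayBuffer())    // filtered out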
Expand All @@ -230,33 +233,34 @@ class CatalystUtil(inputSchema: StructType,
           val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
           val iteratorWrapper: IteratorWrapper[InternalRow] = new IteratorWrapper[InternalRow]
           buffer.init(0, Array(iteratorWrapper))
-          def codegenFunc(row: InternalRow): Option[InternalRow] = {
+          def codegenFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
             iteratorWrapper.put(row)
+            val result = ArrayBuffer.empty[InternalRow]
             while (buffer.hasNext) {
-              return Some(unsafeProjection.apply(buffer.next()))
+              result.append(unsafeProjection.apply(buffer.next()))
             }
-            None
+            result
           }
           codegenFunc
         case _ =>
           val unsafeProjection = UnsafeProjection.create(projectList, childPlan.output)
-          def projectFunc(row: InternalRow): Option[InternalRow] = {
-            Some(unsafeProjection.apply(row))
+          def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] = {
+            ArrayBuffer(unsafeProjection.apply(row))
           }
           projectFunc
       }
     }
     case ltse: LocalTableScanExec => {
       // Input `row` is unused because for LTSE, no input is needed to compute the output
-      def projectFunc(row: InternalRow): Option[InternalRow] =
-        ltse.executeCollect().headOption
+      def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] =
+        ArrayBuffer(ltse.executeCollect(): _*)

       projectFunc
     }
     case rddse: RDDScanExec => {
       val unsafeProjection = UnsafeProjection.create(rddse.schema)
-      def projectFunc(row: InternalRow): Option[InternalRow] =
-        Some(unsafeProjection.apply(row))
+      def projectFunc(row: InternalRow): ArrayBuffer[InternalRow] =
+        ArrayBuffer(unsafeProjection.apply(row))

       projectFunc
     }
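
The LocalTableScanExec branch is where the old Option type concretely lost data: executeCollect().headOption kept only the first row of the local table. Spreading the collected array into an ArrayBuffer preserves every row, e.g.:

    import scala.collection.mutable.ArrayBuffer

    val collected = Array("r1", "r2", "r3")    // stand-in for ltse.executeCollect()
    val before = collected.headOption          // Some("r1"): rows 2 and 3 were dropped
    val after = ArrayBuffer(collected: _*)     // ArrayBuffer("r1", "r2", "r3")
    assert(after.size == 3)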
@@ -45,7 +45,7 @@ object OnlineDerivationUtil
   ): DerivationFunc = {
     {
       case (keys: Map[String, Any], values: Map[String, Any]) =>
-        reintroduceExceptions(catalystUtil.performSql(keys ++ values).orNull, values)
+        reintroduceExceptions(catalystUtil.performSql(keys ++ values).headOption.orNull, values)
     }
   }
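
The derivation layer still expects at most one row per input, so headOption.orNull on the new Seq reproduces the old orNull on an Option exactly; a multi-row result is truncated to its first row here, matching the previous behavior. The equivalence, spelled out:

    val empty: Seq[Map[String, Any]] = Seq.empty
    val one: Seq[Map[String, Any]] = Seq(Map("a" -> 1))

    assert(empty.headOption.orNull == null)        // old: None.orNull
    assert(one.headOption.orNull == Map("a" -> 1)) // old: Some(map).orNull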

@@ -16,10 +16,12 @@ class CatalystUtilHiveUDFTest extends AnyFlatSpec with CatalystUtilTestSparkSQLS
"b" -> "CAT_STR(string_x)"
)
val cu = new CatalystUtil(CommonScalarsStruct, selects = selects, setups = setups)
val res = cu.performSql(CommonScalarsRow)
assertEquals(res.get.size, 2)
assertEquals(res.get("a"), Int.MaxValue - 1)
assertEquals(res.get("b"), "hello123")
val resList = cu.performSql(CommonScalarsRow)
assertEquals(resList.size, 1)
val resMap = resList.head
assertEquals(resMap.size, 2)
assertEquals(resMap.get("a"), Int.MaxValue - 1)
assertEquals(resMap.get("b"), "hello123")
}

override def tagName: String = "catalystUtilHiveUdfTest"
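
The updated assertions still cover only the single-row case. A follow-up test could exercise the fan-out that motivated the Seq return type; a hedged sketch, assuming an explode select lands in the WholeStageCodegenExec path handled above (the test name is invented; fixtures reuse this file's CommonScalarsStruct and CommonScalarsRow, and setups is omitted on the assumption the constructor defaults it):

    it should "return multiple rows for an exploding select" in {
      val explodeSelects = Seq("exploded" -> "explode(array(string_x, string_x))")
      val cu = new CatalystUtil(CommonScalarsStruct, selects = explodeSelects)
      val rows = cu.performSql(CommonScalarsRow)
      assertEquals(rows.size, 2) // one input row fans out to two output rows
    }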