diff --git a/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala b/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala index 9574801950..c329280583 100644 --- a/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala +++ b/flink/src/main/scala/ai/chronon/flink/AsyncKVStoreWriter.scala @@ -48,18 +48,6 @@ object AsyncKVStoreWriter { .setParallelism(inputDS.getParallelism) } - /** This was moved to flink-rpc-akka in Flink 1.16 and made private, so we reproduce the direct execution context here - */ - private class DirectExecutionContext extends ExecutionContext { - override def execute(runnable: Runnable): Unit = - runnable.run() - - override def reportFailure(cause: Throwable): Unit = - throw new IllegalStateException("Error in direct execution context.", cause) - - override def prepare: ExecutionContext = this - } - private val ExecutionContextInstance: ExecutionContext = new DirectExecutionContext } diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkGroupByStreamingJob.scala b/flink/src/main/scala/ai/chronon/flink/FlinkGroupByStreamingJob.scala new file mode 100644 index 0000000000..f93048d952 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/FlinkGroupByStreamingJob.scala @@ -0,0 +1,219 @@ +package ai.chronon.flink + +import ai.chronon.aggregator.windowing.ResolutionUtils +import ai.chronon.api.Extensions.{GroupByOps, SourceOps} +import ai.chronon.api.DataType +import ai.chronon.flink.FlinkJob.watermarkStrategy +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.flink.source.FlinkSource +import ai.chronon.flink.types.{AvroCodecOutput, TimestampedTile, WriteResponse} +import ai.chronon.flink.window.{ + AlwaysFireOnElementTrigger, + BufferedProcessingTimeTrigger, + FlinkRowAggProcessFunction, + FlinkRowAggregationFunction, + KeySelectorBuilder +} +import ai.chronon.online.{GroupByServingInfoParsed, TopicInfo} +import org.apache.flink.streaming.api.datastream.{DataStream, SingleOutputStreamOperator} +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.async.RichAsyncFunction +import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner} +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.streaming.api.windowing.triggers.Trigger +import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.apache.flink.util.OutputTag + +import scala.collection.Seq + +/** Flink job that processes a single streaming GroupBy and writes out the results (in the form of pre-aggregated tiles) to the KV store. + * + * @param eventSrc - Provider of a Flink Datastream[ ProjectedEvent ] for the given topic and groupBy. The event + * consists of a field Map as well as metadata columns such as processing start time (to track + * metrics). The Map contains projected columns from the source data based on projections and filters + * in the GroupBy. 
+ * @param sinkFn - Async Flink writer function to help us write to the KV store + * @param groupByServingInfoParsed - The GroupBy we are working with + * @param parallelism - Parallelism to use for the Flink job + * @param enableDebug - If enabled will log additional debug info per processed event + */ +class FlinkGroupByStreamingJob(eventSrc: FlinkSource[ProjectedEvent], + inputSchema: Seq[(String, DataType)], + sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse], + val groupByServingInfoParsed: GroupByServingInfoParsed, + parallelism: Int, + props: Map[String, String], + topicInfo: TopicInfo, + enableDebug: Boolean = false) + extends BaseFlinkJob { + + val groupByName: String = groupByServingInfoParsed.groupBy.getMetaData.getName + logger.info(f"Creating Flink GroupBy streaming job. groupByName=${groupByName}") + + if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) { + throw new IllegalArgumentException( + s"Invalid groupBy: $groupByName. No streaming source" + ) + } + + private val kvStoreCapacity = FlinkUtils + .getProperty("kv_concurrency", props, topicInfo) + .map(_.toInt) + .getOrElse(AsyncKVStoreWriter.kvStoreConcurrency) + + // The source of our Flink application is a topic + val topic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic + + /** The "untiled" version of the Flink app. + * + * At a high level, the operators are structured as follows: + * source -> Spark expression eval -> Avro conversion -> KV store writer + * source - Reads objects of type T (specific case class, Thrift / Proto) from a topic + * Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data + * Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store + * (PutRequest object) + * KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API + * + * In this untiled version, there are no shuffles and thus this ends up being a single node in the Flink DAG + * (with the above 4 operators and parallelism as injected by the user). + */ + def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + + logger.info( + f"Running Flink job for groupByName=${groupByName}, Topic=${topic}. " + + "Tiling is disabled.") + + // we expect parallelism on the source stream to be set by the source provider + val sourceSparkProjectedStream: DataStream[ProjectedEvent] = + eventSrc + .getDataStream(topic, groupByName)(env, parallelism) + .uid(s"source-$groupByName") + .name(s"Source for $groupByName") + + val sparkExprEvalDSWithWatermarks: DataStream[ProjectedEvent] = sourceSparkProjectedStream + .assignTimestampsAndWatermarks(watermarkStrategy) + .uid(s"spark-expr-eval-timestamps-$groupByName") + .name(s"Spark expression eval with timestamps for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + val putRecordDS: DataStream[AvroCodecOutput] = sparkExprEvalDSWithWatermarks + .flatMap(AvroCodecFn(groupByServingInfoParsed)) + .uid(s"avro-conversion-$groupByName") + .name(s"Avro conversion for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + AsyncKVStoreWriter.withUnorderedWaits( + putRecordDS, + sinkFn, + groupByName, + capacity = kvStoreCapacity + ) + } + + /** The "tiled" version of the Flink app. + * + * The operators are structured as follows: + * 1. source - Reads objects of type T (specific case class, Thrift / Proto) from a topic + * 2. 
Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input + * data + * 3. Window/tiling - This window aggregates incoming events, keeps track of the IRs, and sends them forward so + * they are written out to the KV store + * 4. Avro conversion - Finishes converting the output of the window (the IRs) to a form that can be written out + * to the KV store (PutRequest object) + * 5. KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API + * + * The window causes a split in the Flink DAG, so there are two nodes, (1+2) and (3+4+5). + */ + override def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + logger.info( + f"Running Flink job for groupByName=${groupByName}, Topic=${topic}. " + + "Tiling is enabled.") + + val tilingWindowSizeInMillis: Long = + ResolutionUtils.getSmallestTailHopMillis(groupByServingInfoParsed.groupBy) + + // we expect parallelism on the source stream to be set by the source provider + val sourceSparkProjectedStream: DataStream[ProjectedEvent] = + eventSrc + .getDataStream(topic, groupByName)(env, parallelism) + .uid(s"source-$groupByName") + .name(s"Source for $groupByName") + + val sparkExprEvalDSAndWatermarks: DataStream[ProjectedEvent] = sourceSparkProjectedStream + .assignTimestampsAndWatermarks(watermarkStrategy) + .uid(s"spark-expr-eval-timestamps-$groupByName") + .name(s"Spark expression eval with timestamps for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + val window = TumblingEventTimeWindows + .of(Time.milliseconds(tilingWindowSizeInMillis)) + .asInstanceOf[WindowAssigner[ProjectedEvent, TimeWindow]] + + // We default to the AlwaysFireOnElementTrigger which will cause the window to "FIRE" on every element. + // An alternative is the BufferedProcessingTimeTrigger (trigger=buffered in topic info + // or properties) which will buffer writes and only "FIRE" every X milliseconds per GroupBy & key. + val trigger = getTrigger() + + // We use Flink "Side Outputs" to track any late events that aren't computed. + val tilingLateEventsTag = new OutputTag[ProjectedEvent]("tiling-late-events") {} + + // The tiling operator works the following way: + // 1. Input: Spark expression eval (previous operator) + // 2. Key by the entity key(s) defined in the groupby + // 3. Window by a tumbling window + // 4. Use our custom trigger that will "FIRE" on every element + // 5. the AggregationFunction merges each incoming element with the current IRs which are kept in state + // - Each time a "FIRE" is triggered (i.e. on every event), getResult() is called and the current IRs are emitted + // 6. A process window function does additional processing each time the AggregationFunction emits results + // - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS + // 7. 
Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element + + val tilingDS: SingleOutputStreamOperator[TimestampedTile] = + sparkExprEvalDSAndWatermarks + .keyBy(KeySelectorBuilder.build(groupByServingInfoParsed.groupBy)) + .window(window) + .trigger(trigger) + .sideOutputLateData(tilingLateEventsTag) + .aggregate( + // See Flink's "ProcessWindowFunction with Incremental Aggregation" + new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, inputSchema, enableDebug), + new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, inputSchema, enableDebug) + ) + .uid(s"tiling-01-$groupByName") + .name(s"Tiling for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + // Track late events + tilingDS + .getSideOutput(tilingLateEventsTag) + .flatMap(new LateEventCounter(groupByName)) + .uid(s"tiling-side-output-01-$groupByName") + .name(s"Tiling Side Output Late Data for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + val putRecordDS: DataStream[AvroCodecOutput] = tilingDS + .flatMap(TiledAvroCodecFn(groupByServingInfoParsed, tilingWindowSizeInMillis, enableDebug)) + .uid(s"avro-conversion-01-$groupByName") + .name(s"Avro conversion for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + AsyncKVStoreWriter.withUnorderedWaits( + putRecordDS, + sinkFn, + groupByName, + capacity = kvStoreCapacity + ) + } + + private def getTrigger(): Trigger[ProjectedEvent, TimeWindow] = { + FlinkUtils + .getProperty("trigger", props, topicInfo) + .map { + case "always_fire" => new AlwaysFireOnElementTrigger() + case "buffered" => new BufferedProcessingTimeTrigger(100L) + case t => + throw new IllegalArgumentException(s"Unsupported trigger type: $t. 
Supported: 'always_fire', 'buffered'") + } + .getOrElse(new AlwaysFireOnElementTrigger()) + } +} diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala index 0737dd3baf..90d52e5127 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkJob.scala @@ -1,22 +1,14 @@ package ai.chronon.flink -import ai.chronon.aggregator.windowing.ResolutionUtils import ai.chronon.api.Constants.MetadataDataset import ai.chronon.api.Extensions.{GroupByOps, SourceOps} -import ai.chronon.api.ScalaJavaConversions._ -import ai.chronon.api.{Constants, DataType} -import ai.chronon.flink.FlinkJob.watermarkStrategy +import ai.chronon.api.{Constants, DataModel} +import ai.chronon.flink.{AsyncKVStoreWriter, FlinkGroupByStreamingJob} import ai.chronon.flink.deser.{DeserializationSchemaBuilder, FlinkSerDeProvider, ProjectedEvent, SourceProjection} -import ai.chronon.flink.source.{FlinkSource, FlinkSourceProvider, KafkaFlinkSource} -import ai.chronon.flink.types.{AvroCodecOutput, TimestampedTile, WriteResponse} +import ai.chronon.flink.chaining.ChainedGroupByJob +import ai.chronon.flink.source.FlinkSourceProvider +import ai.chronon.flink.types.WriteResponse import ai.chronon.flink.validation.ValidationFlinkJob -import ai.chronon.flink.window.{ - AlwaysFireOnElementTrigger, - BufferedProcessingTimeTrigger, - FlinkRowAggProcessFunction, - FlinkRowAggregationFunction, - KeySelectorBuilder -} import ai.chronon.online.fetcher.{FetchContext, MetadataStore} import ai.chronon.online.{Api, GroupByServingInfoParsed, TopicInfo} import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy} @@ -25,63 +17,43 @@ import org.apache.flink.api.common.restartstrategy.RestartStrategies import org.apache.flink.configuration.Configuration import org.apache.flink.core.fs.FileSystem import org.apache.flink.streaming.api.CheckpointingMode -import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink, SingleOutputStreamOperator} +import org.apache.flink.streaming.api.datastream.{DataStream, DataStreamSink} import org.apache.flink.streaming.api.environment.CheckpointConfig.ExternalizedCheckpointCleanup import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment -import org.apache.flink.streaming.api.functions.async.RichAsyncFunction -import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner} -import org.apache.flink.streaming.api.windowing.time.Time -import org.apache.flink.streaming.api.windowing.triggers.Trigger -import org.apache.flink.streaming.api.windowing.windows.TimeWindow -import org.apache.flink.util.OutputTag import org.rogach.scallop.{ScallopConf, ScallopOption, Serialization} -import org.slf4j.LoggerFactory +import org.slf4j.{Logger, LoggerFactory} import java.time.Duration import scala.collection.Seq import scala.concurrent.duration.{DurationInt, FiniteDuration} -/** Flink job that processes a single streaming GroupBy and writes out the results (in the form of pre-aggregated tiles) to the KV store. - * - * @param eventSrc - Provider of a Flink Datastream[ ProjectedEvent ] for the given topic and groupBy. The event - * consists of a field Map as well as metadata columns such as processing start time (to track - * metrics). The Map contains projected columns from the source data based on projections and filters - * in the GroupBy. 
- * @param sinkFn - Async Flink writer function to help us write to the KV store - * @param groupByServingInfoParsed - The GroupBy we are working with - * @param parallelism - Parallelism to use for the Flink job - * @param enableDebug - If enabled will log additional debug info per processed event +/** Base abstract class for all Flink streaming jobs in Chronon. + * Defines the common interface and shared functionality for different job types. */ -class FlinkJob(eventSrc: FlinkSource[ProjectedEvent], - inputSchema: Seq[(String, DataType)], - sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse], - groupByServingInfoParsed: GroupByServingInfoParsed, - parallelism: Int, - props: Map[String, String], - topicInfo: TopicInfo, - enableDebug: Boolean = false) { - private[this] val logger = LoggerFactory.getLogger(getClass) - - val groupByName: String = groupByServingInfoParsed.groupBy.getMetaData.getName - logger.info(f"Creating Flink job. groupByName=${groupByName}") - - if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) { - throw new IllegalArgumentException( - s"Invalid groupBy: $groupByName. No streaming source" - ) - } +abstract class BaseFlinkJob { + + protected val logger: Logger = LoggerFactory.getLogger(getClass) + + def groupByName: String - private val kvStoreCapacity = FlinkUtils - .getProperty("kv_concurrency", props, topicInfo) - .map(_.toInt) - .getOrElse(AsyncKVStoreWriter.kvStoreConcurrency) + def groupByServingInfoParsed: GroupByServingInfoParsed - // The source of our Flink application is a topic - val topic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic + /** Run the streaming job with tiling enabled (default mode). + * This is the main execution method that should be implemented by subclasses. + */ + def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] +} + +object FlinkJob { + // we set an explicit max parallelism to ensure if we do make parallelism setting updates, there's still room + // to restore the job from prior state. Number chosen does have perf ramifications if too high (can impact rocksdb perf) + // so we've chosen one that should allow us to scale to jobs in the 10K-50K events / s range. + val MaxParallelism: Int = 1260 // highly composite number def runWriteInternalManifestJob(env: StreamExecutionEnvironment, manifestPath: String, - parentJobId: String): DataStreamSink[String] = { + parentJobId: String, + groupByName: String): DataStreamSink[String] = { // check that the last character is a slash val outputPath = if (!manifestPath.endsWith("/")) { @@ -109,166 +81,6 @@ class FlinkJob(eventSrc: FlinkSource[ProjectedEvent], .setParallelism(1) // Use parallelism 1 to get a single output file } - /** The "untiled" version of the Flink app. - * - * At a high level, the operators are structured as follows: - * source -> Spark expression eval -> Avro conversion -> KV store writer - * source - Reads objects of type T (specific case class, Thrift / Proto) from a topic - * Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data - * Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store - * (PutRequest object) - * KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API - * - * In this untiled version, there are no shuffles and thus this ends up being a single node in the Flink DAG - * (with the above 4 operators and parallelism as injected by the user). 
- */ - def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { - - logger.info( - f"Running Flink job for groupByName=${groupByName}, Topic=${topic}. " + - "Tiling is disabled.") - - // we expect parallelism on the source stream to be set by the source provider - val sourceSparkProjectedStream: DataStream[ProjectedEvent] = - eventSrc - .getDataStream(topic, groupByName)(env, parallelism) - .uid(s"source-$groupByName") - .name(s"Source for $groupByName") - - val sparkExprEvalDSWithWatermarks: DataStream[ProjectedEvent] = sourceSparkProjectedStream - .assignTimestampsAndWatermarks(watermarkStrategy) - .uid(s"spark-expr-eval-timestamps-$groupByName") - .name(s"Spark expression eval with timestamps for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - val putRecordDS: DataStream[AvroCodecOutput] = sparkExprEvalDSWithWatermarks - .flatMap(AvroCodecFn(groupByServingInfoParsed)) - .uid(s"avro-conversion-$groupByName") - .name(s"Avro conversion for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - AsyncKVStoreWriter.withUnorderedWaits( - putRecordDS, - sinkFn, - groupByName, - capacity = kvStoreCapacity - ) - } - - /** The "tiled" version of the Flink app. - * - * The operators are structured as follows: - * 1. source - Reads objects of type T (specific case class, Thrift / Proto) from a topic - * 2. Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input - * data - * 3. Window/tiling - This window aggregates incoming events, keeps track of the IRs, and sends them forward so - * they are written out to the KV store - * 4. Avro conversion - Finishes converting the output of the window (the IRs) to a form that can be written out - * to the KV store (PutRequest object) - * 5. KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API - * - * The window causes a split in the Flink DAG, so there are two nodes, (1+2) and (3+4+5). - */ - def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { - logger.info( - f"Running Flink job for groupByName=${groupByName}, Topic=${topic}. " + - "Tiling is enabled.") - - val tilingWindowSizeInMillis: Long = - ResolutionUtils.getSmallestTailHopMillis(groupByServingInfoParsed.groupBy) - - // we expect parallelism on the source stream to be set by the source provider - val sourceSparkProjectedStream: DataStream[ProjectedEvent] = - eventSrc - .getDataStream(topic, groupByName)(env, parallelism) - .uid(s"source-$groupByName") - .name(s"Source for $groupByName") - - val sparkExprEvalDSAndWatermarks: DataStream[ProjectedEvent] = sourceSparkProjectedStream - .assignTimestampsAndWatermarks(watermarkStrategy) - .uid(s"spark-expr-eval-timestamps-$groupByName") - .name(s"Spark expression eval with timestamps for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - val window = TumblingEventTimeWindows - .of(Time.milliseconds(tilingWindowSizeInMillis)) - .asInstanceOf[WindowAssigner[ProjectedEvent, TimeWindow]] - - // We default to the AlwaysFireOnElementTrigger which will cause the window to "FIRE" on every element. - // An alternative is the BufferedProcessingTimeTrigger (trigger=buffered in topic info - // or properties) which will buffer writes and only "FIRE" every X milliseconds per GroupBy & key. - val trigger = getTrigger() - - // We use Flink "Side Outputs" to track any late events that aren't computed. 
- val tilingLateEventsTag = new OutputTag[ProjectedEvent]("tiling-late-events") {} - - // The tiling operator works the following way: - // 1. Input: Spark expression eval (previous operator) - // 2. Key by the entity key(s) defined in the groupby - // 3. Window by a tumbling window - // 4. Use our custom trigger that will "FIRE" on every element - // 5. the AggregationFunction merges each incoming element with the current IRs which are kept in state - // - Each time a "FIRE" is triggered (i.e. on every event), getResult() is called and the current IRs are emitted - // 6. A process window function does additional processing each time the AggregationFunction emits results - // - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS - // 7. Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element - - val tilingDS: SingleOutputStreamOperator[TimestampedTile] = - sparkExprEvalDSAndWatermarks - .keyBy(KeySelectorBuilder.build(groupByServingInfoParsed.groupBy)) - .window(window) - .trigger(trigger) - .sideOutputLateData(tilingLateEventsTag) - .aggregate( - // See Flink's "ProcessWindowFunction with Incremental Aggregation" - new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, inputSchema, enableDebug), - new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, inputSchema, enableDebug) - ) - .uid(s"tiling-01-$groupByName") - .name(s"Tiling for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - // Track late events - tilingDS - .getSideOutput(tilingLateEventsTag) - .flatMap(new LateEventCounter(groupByName)) - .uid(s"tiling-side-output-01-$groupByName") - .name(s"Tiling Side Output Late Data for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - val putRecordDS: DataStream[AvroCodecOutput] = tilingDS - .flatMap(TiledAvroCodecFn(groupByServingInfoParsed, tilingWindowSizeInMillis, enableDebug)) - .uid(s"avro-conversion-01-$groupByName") - .name(s"Avro conversion for $groupByName") - .setParallelism(sourceSparkProjectedStream.getParallelism) - - AsyncKVStoreWriter.withUnorderedWaits( - putRecordDS, - sinkFn, - groupByName, - capacity = kvStoreCapacity - ) - } - - private def getTrigger(): Trigger[ProjectedEvent, TimeWindow] = { - FlinkUtils - .getProperty("trigger", props, topicInfo) - .map { - case "always_fire" => new AlwaysFireOnElementTrigger() - case "buffered" => new BufferedProcessingTimeTrigger(100L) - case t => - throw new IllegalArgumentException(s"Unsupported trigger type: $t. Supported: 'always_fire', 'buffered'") - } - .getOrElse(new AlwaysFireOnElementTrigger()) - } -} - -object FlinkJob { - // we set an explicit max parallelism to ensure if we do make parallelism setting updates, there's still room - // to restore the job from prior state. Number chosen does have perf ramifications if too high (can impact rocksdb perf) - // so we've chosen one that should allow us to scale to jobs in the 10K-50K events / s range. 
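(Illustrative aside, not part of this diff: max parallelism caps the number of key groups Flink uses for keyed state, which is why changing it later would break state restore. A hypothetical call site would apply the constant to the environment roughly as below; the actual env setup in FlinkJob.main is not shown in this hunk.)

    env.setMaxParallelism(FlinkJob.MaxParallelism) // 1260 key groups; keeps rescaling state-compatible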
- val MaxParallelism: Int = 1260 // highly composite number - // We choose to checkpoint frequently to ensure the incremental checkpoints are small in size // as well as ensuring the catch-up backlog is fairly small in case of failures val CheckPointInterval: FiniteDuration = 10.seconds @@ -409,32 +221,114 @@ object FlinkJob { // Store the mapping between parent job id and flink job id to bucket if (maybeParentJobId.isDefined) { - flinkJob.runWriteInternalManifestJob(env, jobArgs.streamingManifestPath(), maybeParentJobId.get) + FlinkJob.runWriteInternalManifestJob(env, jobArgs.streamingManifestPath(), maybeParentJobId.get, groupByName) } val jobDatastream = flinkJob.runTiledGroupByJob(env) jobDatastream - .addSink(new MetricsSink(flinkJob.groupByName)) - .uid(s"metrics-sink - ${flinkJob.groupByName}") - .name(s"Metrics Sink for ${flinkJob.groupByName}") + .addSink(new MetricsSink(groupByName)) + .uid(s"metrics-sink - $groupByName") + .name(s"Metrics Sink for $groupByName") .setParallelism(jobDatastream.getParallelism) - env.execute(s"${flinkJob.groupByName}") + env.execute(s"$groupByName") } private def buildFlinkJob(groupByName: String, props: Map[String, String], api: Api, servingInfo: GroupByServingInfoParsed, - enableDebug: Boolean = false) = { - val topicUri = servingInfo.groupBy.streamingSource.get.topic - val topicInfo = TopicInfo.parse(topicUri) + enableDebug: Boolean = false): BaseFlinkJob = { + // Check if this is a JoinSource GroupBy + if (servingInfo.groupBy.streamingSource.get.isSetJoinSource) { + buildJoinSourceFlinkJob(groupByName, props, api, servingInfo, enableDebug) + } else { + buildGroupByStreamingJob(groupByName, props, api, servingInfo, enableDebug) + } + } + private def buildJoinSourceFlinkJob(groupByName: String, + props: Map[String, String], + api: Api, + servingInfo: GroupByServingInfoParsed, + enableDebug: Boolean): ChainedGroupByJob = { + val logger = LoggerFactory.getLogger(getClass) + + val joinSource = servingInfo.groupBy.streamingSource.get.getJoinSource + val leftSource = joinSource.getJoin.getLeft + require( + leftSource.dataModel == DataModel.EVENTS, + s"Data model on the left source for chaining jobs must be EVENTs. " + + s"Found: ${leftSource.dataModel} for groupBy: $groupByName" + ) + + val topicInfo = TopicInfo.parse(leftSource.topic) val schemaProvider = FlinkSerDeProvider.build(topicInfo) - val deserializationSchema = - DeserializationSchemaBuilder.buildSourceProjectionDeserSchema(schemaProvider, servingInfo.groupBy, enableDebug) + // Use left source query for deserialization schema - this is the topic & schema we use to drive + // the JoinSource processing + val leftSourceQuery = leftSource.query + val leftSourceGroupByName = s"left_source_${joinSource.getJoin.getMetaData.getName}" + + val deserializationSchema = DeserializationSchemaBuilder.buildSourceProjectionDeserSchema( + schemaProvider, + leftSourceQuery, + leftSourceGroupByName, + DataModel.EVENTS, + enableDebug + ) + + require( + deserializationSchema.isInstanceOf[SourceProjection], + s"Expect created deserialization schema for left source: $leftSourceGroupByName with $topicInfo to mixin SourceProjection. " + + s"We got: ${deserializationSchema.getClass.getSimpleName}" + ) + + val projectedSchema = + try { + deserializationSchema.asInstanceOf[SourceProjection].projectedSchema + } catch { + case _: Exception => + throw new RuntimeException( + s"Failed to perform projection via Spark SQL eval for groupBy: $groupByName. 
Retrieved event schema: \n${schemaProvider.schema}\n" + + s"Make sure the Spark SQL expressions are valid (e.g. column names match the source event schema).") + } + + val source = FlinkSourceProvider.build(props, deserializationSchema, topicInfo) + val sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name, enableDebug) + + logger.info(s"Building JoinSource GroupBy: $groupByName. Using FlinkJoinSourceJob.") + new ChainedGroupByJob( + eventSrc = source, + inputSchema = projectedSchema, + sinkFn = sinkFn, + groupByServingInfoParsed = servingInfo, + parallelism = source.parallelism, + props = props, + topicInfo = topicInfo, + api = api, + enableDebug = enableDebug + ) + } + + private def buildGroupByStreamingJob(groupByName: String, + props: Map[String, String], + api: Api, + servingInfo: GroupByServingInfoParsed, + enableDebug: Boolean): FlinkGroupByStreamingJob = { + val logger = LoggerFactory.getLogger(getClass) + + val topicInfo = TopicInfo.parse(servingInfo.groupBy.streamingSource.get.topic) + val schemaProvider = FlinkSerDeProvider.build(topicInfo) + + // Use the existing GroupBy-based interface for regular GroupBys + val deserializationSchema = DeserializationSchemaBuilder.buildSourceProjectionDeserSchema( + schemaProvider, + servingInfo.groupBy, + enableDebug + ) + require( deserializationSchema.isInstanceOf[SourceProjection], s"Expect created deserialization schema for groupBy: $groupByName with $topicInfo to mixin SourceProjection. " + @@ -452,11 +346,13 @@ object FlinkJob { } val source = FlinkSourceProvider.build(props, deserializationSchema, topicInfo) + val sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name, enableDebug) - new FlinkJob( + logger.info(s"Building regular GroupBy: $groupByName. Using FlinkGroupByStreamingJob.") + new FlinkGroupByStreamingJob( eventSrc = source, - projectedSchema, - sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name, enableDebug), + inputSchema = projectedSchema, + sinkFn = sinkFn, groupByServingInfoParsed = servingInfo, parallelism = source.parallelism, props = props, diff --git a/flink/src/main/scala/ai/chronon/flink/FlinkUtils.scala b/flink/src/main/scala/ai/chronon/flink/FlinkUtils.scala index 88f345b915..3dfe1fe6f1 100644 --- a/flink/src/main/scala/ai/chronon/flink/FlinkUtils.scala +++ b/flink/src/main/scala/ai/chronon/flink/FlinkUtils.scala @@ -2,6 +2,8 @@ package ai.chronon.flink import ai.chronon.online.TopicInfo +import scala.concurrent.ExecutionContext + object FlinkUtils { def getProperty(key: String, props: Map[String, String], topicInfo: TopicInfo): Option[String] = { @@ -13,3 +15,15 @@ object FlinkUtils { } } } + +/** This was moved to flink-rpc-akka in Flink 1.16 and made private, so we reproduce the direct execution context here + */ +private class DirectExecutionContext extends ExecutionContext { + override def execute(runnable: Runnable): Unit = + runnable.run() + + override def reportFailure(cause: Throwable): Unit = + throw new IllegalStateException("Error in direct execution context.", cause) + + override def prepare: ExecutionContext = this +} diff --git a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEval.scala b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEval.scala index b5fac19638..40f7087adb 100644 --- a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEval.scala +++ b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEval.scala @@ -18,31 +18,36 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.collection.Seq -import 
scala.jdk.CollectionConverters.asScalaBufferConverter +import scala.jdk.CollectionConverters._ /** Core utility class for Spark expression evaluation that can be reused across different Flink operators. * This evaluator is instantiated for a given EventType (specific case class object, Thrift / Proto object). - * Based on the selects and where clauses in the GroupBy, this function projects and filters the input data and - * emits a Map which contains the relevant fields & values that are needed to compute the aggregated values for the - * GroupBy. + * Based on the selects and where clauses in the Query, this function projects and filters the input data and + * emits a Map which contains the relevant fields & values that are needed to compute the aggregated values. * This class is meant to be used in Flink operators (e.g. DeserializationSchema, RichMapFunctions) to run Spark expression evals. * * @param encoder Spark Encoder for the input event - * @param groupBy The GroupBy to evaluate. + * @param query The Query to evaluate + * @param groupByName Name for metrics and schema identification * @tparam EventType The type of the input event. */ -class SparkExpressionEval[EventType](encoder: Encoder[EventType], groupBy: GroupBy) extends Serializable { +class SparkExpressionEval[EventType](encoder: Encoder[EventType], + query: Query, + groupByName: String, + dataModel: DataModel = DataModel.EVENTS) + extends Serializable { + import SparkExpressionEval._ @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass) - private val (transforms, filters) = buildQueryTransformsAndFilters(groupBy) + private val (transforms, filters) = buildQueryTransformsAndFilters(query, dataModel) // Chronon's CatalystUtil expects a Chronon `StructType` so we convert the // Encoder[T]'s schema to one. 
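(Illustrative sketch, not part of this diff: with the constructor now taking a Query instead of a full GroupBy, a caller that still starts from a GroupBy could wire the evaluator up as below. `gb` and `encoder` are assumed to be in scope; queryFromGroupBy is the helper introduced further down in this file.)

    val query = SparkExpressionEval.queryFromGroupBy(gb)  // validates and extracts the streaming query
    val eval  = new SparkExpressionEval(encoder, query, gb.metaData.cleanName, gb.dataModel)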
val chrononSchema: ChrononStructType = ChrononStructType.from( - s"${groupBy.metaData.cleanName}", + groupByName, SparkConversions.toChrononSchema(encoder.schema) ) @@ -73,7 +78,8 @@ class SparkExpressionEval[EventType](encoder: Encoder[EventType], groupBy: Group exprEvalErrorCounter = metricsGroup.counter("spark_expr_eval_errors") // Initialize CatalystUtil without acquiring session reference - catalystUtil = new CatalystUtil(chrononSchema, transforms, filters, groupBy.setups) + val setups = Option(query.setups).map(_.asScala).getOrElse(Seq.empty) + catalystUtil = new CatalystUtil(chrononSchema, transforms, filters, setups) } def performSql(row: InternalRow): Seq[Map[String, Any]] = { @@ -107,7 +113,8 @@ class SparkExpressionEval[EventType](encoder: Encoder[EventType], groupBy: Group } def getOutputSchema: StructType = { - new CatalystUtil(chrononSchema, transforms, filters, groupBy.setups).getOutputSparkSchema + val setups = Option(query.setups).map(_.asScala).getOrElse(Seq.empty) + new CatalystUtil(chrononSchema, transforms, filters, setups).getOutputSparkSchema } def close(): Unit = { @@ -185,23 +192,21 @@ class SparkExpressionEval[EventType](encoder: Encoder[EventType], groupBy: Group } object SparkExpressionEval { - def validateQuery(gb: GroupBy): Query = { + def queryFromGroupBy(gb: GroupBy): Query = { require(gb.streamingSource.isDefined, s"Streaming source is missing in GroupBy: ${gb.metaData.cleanName}") val query = gb.streamingSource.get.query require(query != null, s"Streaming query is missing in GroupBy: ${gb.metaData.cleanName}") query } - private def buildQueryTransformsAndFilters(gb: GroupBy): (Seq[(String, String)], Seq[String]) = { - val query = validateQuery(gb) - + def buildQueryTransformsAndFilters(query: Query, + dataModel: DataModel = DataModel.EVENTS): (Seq[(String, String)], Seq[String]) = { val timeColumn = Option(query.timeColumn).getOrElse(Constants.TimeColumn) val reversalColumn = Option(query.reversalColumn).getOrElse(Constants.ReversalColumn) - val mutationTimeColumn = - Option(query.mutationTimeColumn).getOrElse(Constants.MutationTimeColumn) + val mutationTimeColumn = Option(query.mutationTimeColumn).getOrElse(Constants.MutationTimeColumn) val selects = Option(query.selects).map(_.toScala).getOrElse(Map.empty[String, String]) - val transforms: Seq[(String, String)] = gb.dataModel match { + val transforms: Seq[(String, String)] = dataModel match { case DataModel.EVENTS => (selects ++ Map(Constants.TimeColumn -> timeColumn)).toSeq case DataModel.ENTITIES => @@ -212,7 +217,7 @@ object SparkExpressionEval { )).toSeq } - val timeFilters = gb.dataModel match { + val timeFilters = dataModel match { case DataModel.ENTITIES => Seq(s"${Constants.MutationTimeColumn} is NOT NULL", s"$timeColumn is NOT NULL") case DataModel.EVENTS => Seq(s"$timeColumn is NOT NULL") } diff --git a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala index d37ac2b822..d31f352e89 100644 --- a/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala +++ b/flink/src/main/scala/ai/chronon/flink/SparkExpressionEvalFn.scala @@ -1,6 +1,6 @@ package ai.chronon.flink -import ai.chronon.api.GroupBy +import ai.chronon.api.{DataModel, Query} import org.apache.flink.api.common.functions.RichFlatMapFunction import org.apache.flink.configuration.Configuration import org.apache.flink.util.Collector @@ -10,16 +10,21 @@ import org.slf4j.Logger import org.slf4j.LoggerFactory import scala.collection.Seq -/** A Flink 
function that uses Chronon's CatalystUtil (via the SparkExpressionEval) to evaluate the Spark SQL expression in a GroupBy. +/** A Flink function that uses Chronon's CatalystUtil (via the SparkExpressionEval) to evaluate the Spark SQL expression. * This function is instantiated for a given type T (specific case class object, Thrift / Proto object). - * Based on the selects and where clauses in the GroupBy, this function projects and filters the input data and - * emits a Map which contains the relevant fields & values that are needed to compute the aggregated values for the - * GroupBy. + * Based on the selects and where clauses in the Query, this function projects and filters the input data and + * emits a Map which contains the relevant fields & values that are needed to compute the aggregated values. * @param encoder Spark Encoder for the input data type - * @param groupBy The GroupBy to evaluate. + * @param query The Query to evaluate + * @param groupByName Name to use for metrics and logging * @tparam T The type of the input data. */ -class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends RichFlatMapFunction[T, Map[String, Any]] { +class SparkExpressionEvalFn[T](encoder: Encoder[T], + query: Query, + groupByName: String, + dataModel: DataModel = DataModel.EVENTS) + extends RichFlatMapFunction[T, Map[String, Any]] { + @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) @transient private var evaluator: SparkExpressionEval[T] = _ @@ -31,11 +36,12 @@ class SparkExpressionEvalFn[T](encoder: Encoder[T], groupBy: GroupBy) extends Ri val eventExprEncoder = encoder.asInstanceOf[ExpressionEncoder[T]] rowSerializer = eventExprEncoder.createSerializer() - evaluator = new SparkExpressionEval[T](encoder, groupBy) + // Create evaluator with simplified constructor + evaluator = new SparkExpressionEval[T](encoder, query, groupByName, dataModel) val metricsGroup = getRuntimeContext.getMetricGroup .addGroup("chronon") - .addGroup("feature_group", groupBy.getMetaData.getName) + .addGroup("feature_group", groupByName) evaluator.initialize(metricsGroup) } diff --git a/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala b/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala new file mode 100644 index 0000000000..3fec2f4c85 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/chaining/ChainedGroupByJob.scala @@ -0,0 +1,241 @@ +package ai.chronon.flink.chaining + +import ai.chronon.aggregator.windowing.ResolutionUtils +import ai.chronon.api.Extensions.GroupByOps +import ai.chronon.api.ScalaJavaConversions._ +import ai.chronon.api._ +import ai.chronon.flink.{AsyncKVStoreWriter, BaseFlinkJob, FlinkUtils, LateEventCounter, TiledAvroCodecFn} +import ai.chronon.flink.FlinkJob.watermarkStrategy +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.flink.source.FlinkSource +import ai.chronon.flink.types.{AvroCodecOutput, TimestampedTile, WriteResponse} +import ai.chronon.flink.window.{ + AlwaysFireOnElementTrigger, + BufferedProcessingTimeTrigger, + FlinkRowAggProcessFunction, + FlinkRowAggregationFunction, + KeySelectorBuilder +} +import ai.chronon.online.{Api, GroupByServingInfoParsed, TopicInfo} + +import java.util.concurrent.TimeUnit +import scala.collection.Seq +import org.apache.flink.streaming.api.datastream.{AsyncDataStream, DataStream, SingleOutputStreamOperator} +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.streaming.api.functions.async.RichAsyncFunction 
+import org.apache.flink.streaming.api.windowing.assigners.{TumblingEventTimeWindows, WindowAssigner} +import org.apache.flink.streaming.api.windowing.time.Time +import org.apache.flink.streaming.api.windowing.triggers.Trigger +import org.apache.flink.streaming.api.windowing.windows.TimeWindow +import org.apache.flink.util.OutputTag + +/** Flink job implementation for chaining features using JoinSource GroupBys. + * The job reads from event source (that has already performed projections/filters), performs async enrichment, + * and another round of Spark expr eval before proceeding with tiled aggregations and writing to KV store. + * In case of errors during enrichment or query processing, the event is skipped and logged (to ensure we don't + * poison pill the Flink app) + */ +class ChainedGroupByJob(eventSrc: FlinkSource[ProjectedEvent], + inputSchema: Seq[(String, DataType)], + sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse], + val groupByServingInfoParsed: GroupByServingInfoParsed, + parallelism: Int, + props: Map[String, String], + topicInfo: TopicInfo, + api: Api, + enableDebug: Boolean = false) + extends BaseFlinkJob { + + val groupByName: String = groupByServingInfoParsed.groupBy.getMetaData.getName + logger.info(f"Creating Flink JoinSource streaming job. groupByName=${groupByName}") + + // The source of our Flink application is a topic + val topic: String = topicInfo.name + + private val groupByConf = groupByServingInfoParsed.groupBy + + // Validate that this is a JoinSource configuration + require(groupByConf.streamingSource.isDefined, + s"No streaming source present in the groupBy: ${groupByConf.metaData.name}") + require(groupByConf.streamingSource.get.isSetJoinSource, + s"No JoinSource found in the groupBy: ${groupByConf.metaData.name}") + + val joinSource: JoinSource = groupByConf.streamingSource.get.getJoinSource + val leftSource: Source = joinSource.getJoin.getLeft + + // Validate Events-based source + require(leftSource.isSetEvents, + s"Only Events-based sources are currently supported. Found: ${leftSource.getSetField}") + + val keyColumns: Array[String] = groupByConf.keyColumns.toScala.toArray + val valueColumns: Array[String] = groupByConf.aggregationInputs + val eventTimeColumn = Constants.TimeColumn + + // Configuration properties with defaults + private val asyncTimeout: Long = FlinkUtils.getProperty("async_timeout_ms", props, topicInfo).getOrElse("5000").toLong + private val asyncCapacity: Int = FlinkUtils.getProperty("async_capacity", props, topicInfo).getOrElse("100").toInt + + // Configuration properties with defaults + private val kvStoreCapacity = FlinkUtils + .getProperty("kv_concurrency", props, topicInfo) + .map(_.toInt) + .getOrElse(AsyncKVStoreWriter.kvStoreConcurrency) + + // We default to the AlwaysFireOnElementTrigger which will cause the window to "FIRE" on every element. + // An alternative is the BufferedProcessingTimeTrigger (trigger=buffered in topic info + // or properties) which will buffer writes and only "FIRE" every X milliseconds per GroupBy & key. + private def getTrigger(): Trigger[ProjectedEvent, TimeWindow] = { + FlinkUtils.getProperty("trigger", props, topicInfo).getOrElse("always_fire") match { + case "always_fire" => new AlwaysFireOnElementTrigger() + case "buffered" => new BufferedProcessingTimeTrigger(100L) + case t => + throw new IllegalArgumentException(s"Unsupported trigger type: $t. 
Supported: 'always_fire', 'buffered'") + } + } + + /** Build the tiled version of the Flink GroupBy job that chains features using a JoinSource. + * The operators are structured as follows: + * - Source: Read from Kafka topic into ProjectedEvent stream + * - Assign timestamps and watermarks based on event time column + * - Async Enrichment: Use JoinEnrichmentAsyncFunction to fetch join data asynchronously + * - Join Source Query: Apply join source query transformations using JoinSourceQueryFunction + * - Avro Conversion: Convert enriched events to AvroCodecOutput format for KV + * - Sink: Write to KV store using AsyncKVStoreWriter + */ + override def runTiledGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = { + logger.info( + s"Building Flink streaming job for groupBy: $groupByName that chains join: ${joinSource.getJoin.getMetaData.getName}" + + s" using topic: $topic") + + // we expect parallelism on the source stream to be set by the source provider + val sourceSparkProjectedStream: DataStream[ProjectedEvent] = eventSrc + .getDataStream(topic, groupByName)(env, parallelism) + .uid(s"join-source-$groupByName") + .name(s"Join Source for $groupByName") + + val watermarkedStream = sourceSparkProjectedStream + .assignTimestampsAndWatermarks(watermarkStrategy) + .uid(s"join-source-watermarks-$groupByName") + .name(s"Spark expression eval with timestamps for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + val enrichmentFunction = new JoinEnrichmentAsyncFunction( + joinSource.join.metaData.getName, + groupByName, + api, + enableDebug + ) + + val enrichedStream = AsyncDataStream + .unorderedWait( + watermarkedStream, + enrichmentFunction, + asyncTimeout, + TimeUnit.MILLISECONDS, + asyncCapacity + ) + .uid(s"join-enrichment-$groupByName") + .name(s"Async Join Enrichment for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + // Apply join source query transformations only if there are transformations to apply + val processedStream = + if (joinSource.query != null && joinSource.query.selects != null && !joinSource.query.selects.isEmpty) { + logger.info("Applying join source query transformations") + val queryFunction = new JoinSourceQueryFunction( + joinSource, + inputSchema, + groupByName, + api, + enableDebug + ) + + enrichedStream + .flatMap(queryFunction) + .uid(s"join-source-query-$groupByName") + .name(s"Join Source Query for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + } else { + logger.info("No join source query transformations to apply - using enriched stream directly") + enrichedStream + } + + // Compute the output schema after JoinSourceQueryFunction transformations using Catalyst + val postTransformationSchema = computePostTransformationSchemaWithCatalyst(joinSource, inputSchema) + + // Calculate tiling window size based on the GroupBy configuration + val tilingWindowSizeInMillis: Long = + ResolutionUtils.getSmallestTailHopMillis(groupByServingInfoParsed.groupBy) + + // Configure tumbling window for tiled aggregations + val window = TumblingEventTimeWindows + .of(Time.milliseconds(tilingWindowSizeInMillis)) + .asInstanceOf[WindowAssigner[ProjectedEvent, TimeWindow]] + + // Configure trigger (default to always fire on element) + val trigger = getTrigger() + + // We use Flink "Side Outputs" to track any late events that aren't computed. 
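(Illustrative aside, not part of this diff: the knobs above are resolved via FlinkUtils.getProperty, so they can be supplied either through topic info params or through the job's props map. The key names below are the ones used in this class; the values are made up.)

    val tuningProps = Map(
      "trigger"          -> "buffered", // buffer tile writes instead of firing per element
      "async_timeout_ms" -> "2500",     // join enrichment fetch timeout
      "async_capacity"   -> "200",      // max in-flight join fetches
      "kv_concurrency"   -> "50"        // AsyncKVStoreWriter capacity
    )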
+ val tilingLateEventsTag = new OutputTag[ProjectedEvent]("tiling-late-events") {} + + // Tiled aggregation: key by entity keys, window, and aggregate + val tilingDS: SingleOutputStreamOperator[TimestampedTile] = + processedStream + .keyBy(KeySelectorBuilder.build(groupByServingInfoParsed.groupBy)) + .window(window) + .trigger(trigger) + .sideOutputLateData(tilingLateEventsTag) + .aggregate( + // Aggregation function that maintains incremental IRs in state + new FlinkRowAggregationFunction(groupByServingInfoParsed.groupBy, postTransformationSchema, enableDebug), + // Process function that marks tiles as closed for client-side caching + new FlinkRowAggProcessFunction(groupByServingInfoParsed.groupBy, postTransformationSchema, enableDebug) + ) + .uid(s"tiling-$groupByName") + .name(s"Tiling for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + // Track late events + tilingDS + .getSideOutput(tilingLateEventsTag) + .flatMap(new LateEventCounter(groupByName)) + .uid(s"tiling-side-output-$groupByName") + .name(s"Tiling Side Output Late Data for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + // Convert tiles to AvroCodecOutput format for KV store writing + val avroConvertedStream = tilingDS + .flatMap(TiledAvroCodecFn(groupByServingInfoParsed, tilingWindowSizeInMillis, enableDebug)) + .uid(s"avro-conversion-$groupByName") + .name(s"Avro Conversion for $groupByName") + .setParallelism(sourceSparkProjectedStream.getParallelism) + + // Write to KV store using existing AsyncKVStoreWriter + AsyncKVStoreWriter.withUnorderedWaits( + avroConvertedStream, + sinkFn, + groupByName, + capacity = kvStoreCapacity + ) + } + + /** Compute the schema that results after JoinSourceQueryFunction transformations. + * If there are no Query transforms defined in the JoinSource, we return the join schema + * (which includes the enrichment fields). Else, we get the output schema from CatalystUtil. 
+ */ + private def computePostTransformationSchemaWithCatalyst( + joinSource: JoinSource, + originalInputSchema: Seq[(String, DataType)]): Seq[(String, DataType)] = { + if (joinSource.query == null || joinSource.query.selects == null || joinSource.query.selects.isEmpty) { + // No transformations applied, return join schema (includes enrichment) + val joinSchema = JoinSourceQueryFunction.buildJoinSchema(originalInputSchema, joinSource, api, enableDebug) + joinSchema.fields.map { field => + (field.name, field.fieldType) + }.toSeq + } else { + // Use shared method to determine the exact output schema + val result = JoinSourceQueryFunction.buildCatalystUtil(joinSource, originalInputSchema, api, enableDebug) + result.outputSchema + } + } +} diff --git a/flink/src/main/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunction.scala b/flink/src/main/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunction.scala new file mode 100644 index 0000000000..6a13f423f7 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunction.scala @@ -0,0 +1,131 @@ +package ai.chronon.flink.chaining + +import ai.chronon.flink.DirectExecutionContext +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.online.fetcher.Fetcher +import ai.chronon.online.Api +import org.apache.flink.configuration.Configuration +import org.apache.flink.dropwizard.metrics.DropwizardHistogramWrapper +import org.apache.flink.metrics.{Counter, Histogram} +import org.apache.flink.streaming.api.functions.async.{ResultFuture, RichAsyncFunction} +import org.slf4j.{Logger, LoggerFactory} + +import scala.concurrent.ExecutionContext +import scala.util.{Failure, Success} + +/** Async function for performing join enrichment on streaming data. + * + * This function takes ProjectedEvent objects (from the left source after query application) + * and enriches them with features from upstream joins, producing enriched ProjectedEvent objects + * that contain both original fields and joined features. 
+ * + * @param joinRequestName The name of the join to fetch (format: "joins/join_name") + * @param api API implementation for fetcher access + * @param enableDebug Whether to enable debug logging + */ +class JoinEnrichmentAsyncFunction(joinRequestName: String, groupByName: String, api: Api, enableDebug: Boolean) + extends RichAsyncFunction[ProjectedEvent, ProjectedEvent] { + + @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass) + @transient private var fetcher: Fetcher = _ + @transient private var successCounter: Counter = _ + @transient private var errorCounter: Counter = _ + @transient private var notFoundCounter: Counter = _ + @transient private var joinFetchLatencyHistogram: Histogram = _ + + // The context used for the future callbacks + implicit lazy val ec: ExecutionContext = JoinEnrichmentAsyncFunction.ExecutionContextInstance + + override def open(parameters: Configuration): Unit = { + super.open(parameters) + + logger.info("Initializing Fetcher for JoinEnrichmentAsyncFunction") + fetcher = api.buildFetcher(debug = enableDebug) + + val group = getRuntimeContext.getMetricGroup + .addGroup("chronon") + .addGroup("group_by", groupByName) + .addGroup("join_enrichment", joinRequestName) + + successCounter = group.counter("join_fetch.successes") + errorCounter = group.counter("join_fetch.errors") + notFoundCounter = group.counter("join_fetch.not_found") + joinFetchLatencyHistogram = group.histogram( + "join_fetch_latency", + new DropwizardHistogramWrapper( + new com.codahale.metrics.Histogram(new com.codahale.metrics.ExponentiallyDecayingReservoir()) + ) + ) + logger.info(s"JoinEnrichmentAsyncFunction initialized for join: $joinRequestName") + } + + override def asyncInvoke(event: ProjectedEvent, resultFuture: ResultFuture[ProjectedEvent]): Unit = { + // Pass all left source field names to match Spark JoinSourceRunner approach + val scalaKeyMap: Map[String, AnyRef] = event.fields.map { case (k, v) => k -> v.asInstanceOf[AnyRef] }.toMap + // Create join request + val request = Fetcher.Request(joinRequestName, scalaKeyMap) + + if (enableDebug) { + logger.info(s"Join request: ${request.keys}, ts: ${request.atMillis}") + } + + // Start latency measurement + val startTime = System.currentTimeMillis() + + // Perform async join fetch + val future = fetcher.fetchJoin(Seq(request)) + + future.onComplete { + case Success(responses) => + // Record latency and increment success counter + joinFetchLatencyHistogram.update(System.currentTimeMillis() - startTime) + successCounter.inc() + + if (responses.nonEmpty) { + val response = responses.head + val responseMap = response.values.getOrElse(Map.empty[String, Any]) + val enrichedFields = event.fields ++ responseMap + + if (enableDebug) { + logger.info( + s"Join response: request=${response.request.keys}, " + + s"ts=${response.request.atMillis}, values=${response.values}") + } + + val enrichedEvent = ProjectedEvent( + enrichedFields, + event.startProcessingTimeMillis + ) + resultFuture.complete(java.util.Collections.singleton(enrichedEvent)) + } else { + // No join response, swallow the event and increment not found counter + notFoundCounter.inc() + resultFuture.complete(java.util.Collections.emptyList()) + } + + case Failure(ex) => + // Record latency and increment error counter + joinFetchLatencyHistogram.update(System.currentTimeMillis() - startTime) + errorCounter.inc() + + logger.error("Error fetching join data", ex) + // we swallow the event on error as there might be downstream join source queries dependent on the + // 
enrichment fields + resultFuture.complete(java.util.Collections.emptyList()) + } + } + + override def timeout(event: ProjectedEvent, resultFuture: ResultFuture[ProjectedEvent]): Unit = { + // Increment error counter for timeout + errorCounter.inc() + + logger.warn(s"Join enrichment timeout for event: ${event.fields}") + // we swallow the event on error as there might be downstream join source queries dependent on the + // enrichment fields + resultFuture.complete(java.util.Collections.emptyList()) + } +} + +object JoinEnrichmentAsyncFunction { + private val ExecutionContextInstance: ExecutionContext = new DirectExecutionContext +} diff --git a/flink/src/main/scala/ai/chronon/flink/chaining/JoinSourceQueryFunction.scala b/flink/src/main/scala/ai/chronon/flink/chaining/JoinSourceQueryFunction.scala new file mode 100644 index 0000000000..138a3224f3 --- /dev/null +++ b/flink/src/main/scala/ai/chronon/flink/chaining/JoinSourceQueryFunction.scala @@ -0,0 +1,178 @@ +package ai.chronon.flink.chaining + +import ai.chronon.api.{Constants, DataType, JoinSource, StructField, StructType} +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.online.{Api, CatalystUtil, JoinCodec} +import ai.chronon.online.serde.SparkConversions +import org.apache.flink.api.common.functions.RichFlatMapFunction +import org.apache.flink.configuration.Configuration +import org.apache.flink.dropwizard.metrics.DropwizardHistogramWrapper +import org.apache.flink.metrics.{Counter, Histogram} +import org.apache.flink.util.Collector +import org.slf4j.{Logger, LoggerFactory} + +import scala.jdk.CollectionConverters._ + +/** Flink function for applying join source query transformations to enriched events. + * + * This function takes ProjectedEvent objects (after join enrichment) and applies + * the joinSource.query SQL expressions to produce the final processed events + * for aggregation and storage. 
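(Illustrative sketch, not part of this diff: the companion's buildCatalystUtil below is the same helper ChainedGroupByJob uses to derive the post-transformation schema, so a standalone caller could reuse it as follows. `joinSource`, `leftSchema`, `enrichedFields`, and `api` are assumed to be in scope.)

    val result = JoinSourceQueryFunction.buildCatalystUtil(joinSource, leftSchema, api, enableDebug = false)
    val projectedRows = result.catalystUtil.performSql(enrichedFields) // one Map per emitted row
    val outputSchema: Seq[(String, DataType)] = result.outputSchema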
+ * + * @param joinSource The JoinSource configuration containing the query to apply + * @param inputSchema Schema of the left source (before enrichment) + * @param api API implementation for join codec access + * @param enableDebug Whether to enable debug logging + */ +class JoinSourceQueryFunction(joinSource: JoinSource, + inputSchema: Seq[(String, DataType)], + groupByName: String, + api: Api, + enableDebug: Boolean) + extends RichFlatMapFunction[ProjectedEvent, ProjectedEvent] { + + @transient private lazy val logger: Logger = LoggerFactory.getLogger(getClass) + @transient private var catalystUtil: CatalystUtil = _ + @transient private var successCounter: Counter = _ + @transient private var errorCounter: Counter = _ + @transient private var queryLatencyHistogram: Histogram = _ + + override def open(parameters: Configuration): Unit = { + super.open(parameters) + + logger.info("Initializing CatalystUtil for join source query evaluation") + + val result = JoinSourceQueryFunction.buildCatalystUtil(joinSource, inputSchema, api, enableDebug) + catalystUtil = result.catalystUtil + + val group = getRuntimeContext.getMetricGroup + .addGroup("chronon") + .addGroup("join_source", joinSource.join.metaData.getName) + .addGroup("group_by", groupByName) + + successCounter = group.counter("query_eval.successes") + errorCounter = group.counter("query_eval.errors") + queryLatencyHistogram = group.histogram( + "query_eval_latency", + new DropwizardHistogramWrapper( + new com.codahale.metrics.Histogram(new com.codahale.metrics.ExponentiallyDecayingReservoir()) + ) + ) + + logger.info(s"Initialized CatalystUtil with join schema") + } + + override def flatMap(enrichedEvent: ProjectedEvent, out: Collector[ProjectedEvent]): Unit = { + val startTime = System.currentTimeMillis() + try { + val queryResults = catalystUtil.performSql(enrichedEvent.fields) + + queryLatencyHistogram.update(System.currentTimeMillis() - startTime) + successCounter.inc() + + if (enableDebug) { + logger.info(s"Join source query input: ${enrichedEvent.fields}") + logger.info(s"Join source query results: $queryResults") + } + + // Output each result as a ProjectedEvent + queryResults.foreach { resultFields => + val resultEvent = ProjectedEvent(resultFields, enrichedEvent.startProcessingTimeMillis) + out.collect(resultEvent) + } + } catch { + case ex: Exception => + // we swallow the event on error + errorCounter.inc() + queryLatencyHistogram.update(System.currentTimeMillis() - startTime) + logger.error(s"Error applying join source query to event: ${enrichedEvent.fields}", ex) + } + } +} + +case class JoinSourceQueryResult( + catalystUtil: CatalystUtil, + joinSchema: StructType, + outputSchema: Seq[(String, DataType)] +) + +object JoinSourceQueryFunction { + private val logger: Logger = LoggerFactory.getLogger(getClass) + + def buildCatalystUtil( + joinSource: JoinSource, + inputSchema: Seq[(String, DataType)], + api: Api, + enableDebug: Boolean + ): JoinSourceQueryResult = { + + // Build join schema (leftSourceSchema + joinCodec.valueSchema) + val joinSchema = buildJoinSchema(inputSchema, joinSource, api, enableDebug) + + // Handle time column mapping like SparkExpressionEval.buildQueryTransformsAndFilters + val timeColumn = Option(joinSource.query.timeColumn).getOrElse(Constants.TimeColumn) + val rawSelects = joinSource.query.selects.asScala.toMap + + val timeColumnMapping = Map(Constants.TimeColumn -> timeColumn) // Add ts -> timeColumn mapping + val selectsWithTimeColumn = (rawSelects ++ timeColumnMapping).toSeq + val wheres = 
Option(joinSource.query.wheres).map(_.asScala).getOrElse(Seq.empty) + + // Create CatalystUtil instance + val catalystUtil = new CatalystUtil(joinSchema, selectsWithTimeColumn, wheres) + + // Get the output schema from Catalyst and convert to Chronon format + val outputSparkSchema = catalystUtil.getOutputSparkSchema + val outputSchema = outputSparkSchema.fields.map { field => + val chrononType = SparkConversions.toChrononType(field.name, field.dataType) + (field.name, chrononType) + }.toSeq + + logger.info(s""" + |Building CatalystUtil for join source query: + |Selects with time column: ${selectsWithTimeColumn} + |Wheres: ${wheres} + |Time column mapping: ${Constants.TimeColumn} -> ${timeColumn} + |Output schema: ${outputSchema.map { case (name, dataType) => s"$name: $dataType" }.mkString(", ")} + |""".stripMargin) + + JoinSourceQueryResult(catalystUtil, joinSchema, outputSchema) + } + + /** Build the join schema following JoinSourceRunner.buildSchemas approach: + * joinSchema = leftSourceSchema ++ joinCodec.valueSchema + */ + def buildJoinSchema( + inputSchema: Seq[(String, DataType)], + joinSource: JoinSource, + api: Api, + enableDebug: Boolean + ): StructType = { + // leftSourceSchema: Convert inputSchema to Chronon StructType + val leftSourceFields = inputSchema.map { case (name, dataType) => + StructField(name, dataType) + }.toArray + val leftSourceSchema = StructType("left_source", leftSourceFields) + + // joinCodec.valueSchema: Get schema of enriched fields from upstream join + val joinCodec: JoinCodec = api + .buildFetcher(debug = enableDebug) + .metadataStore + .buildJoinCodec(joinSource.getJoin, refreshOnFail = false) + + // joinSchema = leftSourceSchema ++ joinCodec.valueSchema + val joinFields = leftSourceSchema.fields ++ joinCodec.valueSchema.fields + val joinSchema = StructType("join_enriched", joinFields) + + logger.info(s""" + |Schema building for join source query: + |leftSourceSchema (${leftSourceFields.length} fields): + | ${leftSourceFields.map(f => s"${f.name}: ${f.fieldType}").mkString(", ")} + |joinCodec.valueSchema (${joinCodec.valueSchema.fields.length} fields): + | ${joinCodec.valueSchema.fields.map(f => s"${f.name}: ${f.fieldType}").mkString(", ")} + |joinSchema (${joinFields.length} fields): + | ${joinFields.map(f => s"${f.name}: ${f.fieldType}").mkString(", ")} + |""".stripMargin) + + joinSchema + } +} diff --git a/flink/src/main/scala/ai/chronon/flink/deser/ChrononDeserializationSchema.scala b/flink/src/main/scala/ai/chronon/flink/deser/ChrononDeserializationSchema.scala index bd42fbb901..204e52ea93 100644 --- a/flink/src/main/scala/ai/chronon/flink/deser/ChrononDeserializationSchema.scala +++ b/flink/src/main/scala/ai/chronon/flink/deser/ChrononDeserializationSchema.scala @@ -1,8 +1,10 @@ package ai.chronon.flink.deser import ai.chronon.api -import ai.chronon.api.GroupBy +import ai.chronon.api.{DataModel, GroupBy, Query} +import ai.chronon.api.Extensions.GroupByOps import ai.chronon.online.serde.SerDe +import ai.chronon.flink.SparkExpressionEval import org.apache.flink.api.common.serialization.AbstractDeserializationSchema import org.apache.spark.sql.{Encoder, Row} @@ -30,12 +32,31 @@ object DeserializationSchemaBuilder { def buildSourceIdentityDeserSchema(provider: SerDe, groupBy: GroupBy, enableDebug: Boolean = false): ChrononDeserializationSchema[Row] = { - new SourceIdentityDeserializationSchema(provider, groupBy, enableDebug) + new SourceIdentityDeserializationSchema(provider, groupBy.getMetaData.getName, enableDebug) + } + + def 
buildSourceIdentityDeserSchema(provider: SerDe, + groupByName: String, + enableDebug: Boolean): ChrononDeserializationSchema[Row] = { + new SourceIdentityDeserializationSchema(provider, groupByName, enableDebug) } def buildSourceProjectionDeserSchema(provider: SerDe, groupBy: GroupBy, enableDebug: Boolean = false): ChrononDeserializationSchema[ProjectedEvent] = { - new SourceProjectionDeserializationSchema(provider, groupBy, enableDebug) + val query = SparkExpressionEval.queryFromGroupBy(groupBy) + new SourceProjectionDeserializationSchema(provider, + query, + groupBy.getMetaData.getName, + groupBy.dataModel, + enableDebug) + } + + def buildSourceProjectionDeserSchema(provider: SerDe, + query: Query, + groupByName: String, + dataModel: DataModel, + enableDebug: Boolean): ChrononDeserializationSchema[ProjectedEvent] = { + new SourceProjectionDeserializationSchema(provider, query, groupByName, dataModel, enableDebug) } } diff --git a/flink/src/main/scala/ai/chronon/flink/deser/DeserializationSchema.scala b/flink/src/main/scala/ai/chronon/flink/deser/DeserializationSchema.scala index 3071f6b119..801ec37bfa 100644 --- a/flink/src/main/scala/ai/chronon/flink/deser/DeserializationSchema.scala +++ b/flink/src/main/scala/ai/chronon/flink/deser/DeserializationSchema.scala @@ -1,6 +1,6 @@ package ai.chronon.flink.deser -import ai.chronon.api.{DataType, GroupBy} +import ai.chronon.api.{DataModel, DataType, Query} import ai.chronon.flink.SparkExpressionEval import ai.chronon.online.serde.{Mutation, SerDe, SparkConversions} import com.codahale.metrics.ExponentiallyDecayingReservoir @@ -13,7 +13,9 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.{Encoder, Encoders, Row} import org.slf4j.{Logger, LoggerFactory} -abstract class BaseDeserializationSchema[T](deserSchemaProvider: SerDe, groupBy: GroupBy, enableDebug: Boolean = false) +abstract class BaseDeserializationSchema[T](deserSchemaProvider: SerDe, + groupByName: String, + enableDebug: Boolean = false) extends ChrononDeserializationSchema[T] { @transient lazy val logger: Logger = LoggerFactory.getLogger(getClass) @@ -31,7 +33,7 @@ abstract class BaseDeserializationSchema[T](deserSchemaProvider: SerDe, groupBy: super.open(context) val metricsGroup = context.getMetricGroup .addGroup("chronon") - .addGroup("group_by", groupBy.getMetaData.getName) + .addGroup("group_by", groupByName) deserializationErrorCounter = metricsGroup.counter("deserialization_errors") deserTimeHistogram = metricsGroup.histogram( "event_deser_time", @@ -65,8 +67,11 @@ abstract class BaseDeserializationSchema[T](deserSchemaProvider: SerDe, groupBy: } } -class SourceIdentityDeserializationSchema(deserSchemaProvider: SerDe, groupBy: GroupBy, enableDebug: Boolean = false) - extends BaseDeserializationSchema[Row](deserSchemaProvider, groupBy, enableDebug) { +/** Implementation of the Flink DeserializationSchema interface that handles deser of events without applying source + * projection. 
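+ * Deserialized events are emitted as Spark Rows in the source event schema; projection and filtering
+ * (if any) are applied by downstream operators instead.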
+ */ +class SourceIdentityDeserializationSchema(deserSchemaProvider: SerDe, groupByName: String, enableDebug: Boolean = false) + extends BaseDeserializationSchema[Row](deserSchemaProvider, groupByName, enableDebug) { override def deserialize(messageBytes: Array[Byte], out: Collector[Row]): Unit = { val maybeMutation = doDeserializeMutation(messageBytes) @@ -89,8 +94,15 @@ class SourceIdentityDeserializationSchema(deserSchemaProvider: SerDe, groupBy: G */ case class ProjectedEvent(fields: Map[String, Any], startProcessingTimeMillis: Long) -class SourceProjectionDeserializationSchema(deserSchemaProvider: SerDe, groupBy: GroupBy, enableDebug: Boolean = false) - extends BaseDeserializationSchema[ProjectedEvent](deserSchemaProvider, groupBy, enableDebug) +/** Implementation of the Flink DeserializationSchema interface that handles deser of events along with applying source + * projection and filters using the provided Query on the source using SparkExpressionEval. + */ +class SourceProjectionDeserializationSchema(deserSchemaProvider: SerDe, + query: Query, + groupByName: String, + dataModel: DataModel, + enableDebug: Boolean = false) + extends BaseDeserializationSchema[ProjectedEvent](deserSchemaProvider, groupByName, enableDebug) with SourceProjection { @transient private var evaluator: SparkExpressionEval[Row] = _ @@ -100,7 +112,7 @@ class SourceProjectionDeserializationSchema(deserSchemaProvider: SerDe, groupBy: override def sourceProjectionEnabled: Boolean = true override def projectedSchema: Array[(String, DataType)] = { - val evaluator = new SparkExpressionEval[Row](sourceEventEncoder, groupBy) + val evaluator = new SparkExpressionEval[Row](sourceEventEncoder, query, groupByName, dataModel) evaluator.getOutputSchema.fields.map { field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)) @@ -111,14 +123,14 @@ class SourceProjectionDeserializationSchema(deserSchemaProvider: SerDe, groupBy: super.open(context) val metricsGroup = context.getMetricGroup .addGroup("chronon") - .addGroup("feature_group", groupBy.getMetaData.getName) + .addGroup("feature_group", groupByName) performSqlErrorCounter = metricsGroup.counter("sql_exec_errors") // spark expr eval vars val eventExprEncoder = sourceEventEncoder.asInstanceOf[ExpressionEncoder[Row]] rowSerializer = eventExprEncoder.createSerializer() - evaluator = new SparkExpressionEval[Row](sourceEventEncoder, groupBy) + evaluator = new SparkExpressionEval[Row](sourceEventEncoder, query, groupByName, dataModel) evaluator.initialize(metricsGroup) } diff --git a/flink/src/main/scala/ai/chronon/flink/validation/ValidationFlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/validation/ValidationFlinkJob.scala index 2a6419ac9b..e92762d7c4 100644 --- a/flink/src/main/scala/ai/chronon/flink/validation/ValidationFlinkJob.scala +++ b/flink/src/main/scala/ai/chronon/flink/validation/ValidationFlinkJob.scala @@ -2,7 +2,7 @@ package ai.chronon.flink.validation import ai.chronon.api.Extensions.{GroupByOps, SourceOps} import ai.chronon.flink.validation.SparkExprEvalComparisonFn.compareResultRows -import ai.chronon.flink.SparkExpressionEvalFn +import ai.chronon.flink.{SparkExpressionEval, SparkExpressionEvalFn} import ai.chronon.online.fetcher.MetadataStore import ai.chronon.online.{GroupByServingInfoParsed, TopicInfo} import org.apache.flink.api.common.typeinfo.TypeInformation @@ -129,10 +129,11 @@ class ValidationFlinkJob(eventSrc: FlinkSource[Row], .name(s"Source with ID for $groupByName") .setParallelism(sourceStream.getParallelism) // Use 
same parallelism as previous operator + val query = SparkExpressionEval.queryFromGroupBy(groupByServingInfoParsed.groupBy) sourceStreamWithId .countWindowAll(validationRows) - .apply( - new SparkDFVsCatalystComparisonFn(new SparkExpressionEvalFn[Row](encoder, groupByServingInfoParsed.groupBy))) + .apply(new SparkDFVsCatalystComparisonFn( + new SparkExpressionEvalFn[Row](encoder, query, groupByName, groupByServingInfoParsed.groupBy.dataModel))) .returns(TypeInformation.of(classOf[ValidationStats])) .uid(s"validation-stats-$groupByName") .name(s"Validation stats for $groupByName") diff --git a/flink/src/test/scala/ai/chronon/flink/chaining/ChainedGroupByJobIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/chaining/ChainedGroupByJobIntegrationTest.scala new file mode 100644 index 0000000000..c9aab7e4bc --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/chaining/ChainedGroupByJobIntegrationTest.scala @@ -0,0 +1,152 @@ +package ai.chronon.flink.chaining + +import ai.chronon.api._ +import ai.chronon.api.Extensions._ +import ai.chronon.flink.test.{CollectSink, FlinkTestUtils, MockAsyncKVStoreWriter} +import ai.chronon.flink.SparkExpressionEvalFn +import ai.chronon.online.serde.SparkConversions +import ai.chronon.online.Extensions.StructTypeOps +import ai.chronon.api.ScalaJavaConversions._ +import ai.chronon.online.{Api, GroupByServingInfoParsed, TopicInfo} +import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment +import org.apache.flink.test.util.MiniClusterWithClientResource +import org.apache.spark.sql.Encoders +import org.scalatest.BeforeAndAfter +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.collection.JavaConverters._ +import scala.collection.Seq + +class ChainedGroupByJobIntegrationTest extends AnyFlatSpec with BeforeAndAfter with Matchers { + import JoinTestUtils._ + + val flinkCluster = new MiniClusterWithClientResource( + new MiniClusterResourceConfiguration.Builder() + .setNumberSlotsPerTaskManager(8) + .setNumberTaskManagers(1) + .build) + + before { + flinkCluster.before() + CollectSink.values.clear() + } + + after { + flinkCluster.after() + CollectSink.values.clear() + } + + it should "run full FlinkJoinSourceJob pipeline with proper test implementations" in { + implicit val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment + env.setParallelism(1) + + // Clear previous test results + CollectSink.values.clear() + + // Test input events + val elements = Seq( + JoinTestEvent("user1", "listing1", 50.0, 1699366993123L), + JoinTestEvent("user2", "listing2", 150.0, 1699366993124L), + JoinTestEvent("user3", "listing3", 25.0, 1699366993125L) + ) + + val testApi = new TestApi() + + val groupBy = buildJoinSourceTerminalGroupBy() + val (joinSourceJob, _) = buildFlinkJoinSourceJob(groupBy, elements, testApi) + + // Run the actual FlinkJoinSourceJob pipeline with our test implementations + val jobDataStream = joinSourceJob.runTiledGroupByJob(env) + jobDataStream.addSink(new CollectSink) + + // Execute the pipeline + println(s"Starting Flink execution...") + env.execute("FlinkJoinSourceJob Integration Test") + + // Verify outputs + val outputs = CollectSink.values.asScala + outputs should not be empty + + // Verify number of output events - should be as many as input events + outputs.size should be >= elements.size + + // Verify basic structure of all responses + outputs.foreach { response => 
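+ // each response should reflect a successful write, since MockAsyncKVStoreWriter is configured with Seq(true) results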
+ response should not be null + response.tsMillis should be > 0L + response.valueBytes should not be null + response.valueBytes.length should be > 0 + response.status should be (true) + response.dataset should not be null + } + + // check that the timestamps of the written out events match the input events + // we use a Set as we can have elements out of order given we have multiple tasks + val timestamps = outputs.map(_.tsMillis).toSet + val expectedTimestamps = elements.map(_.created).toSet + timestamps.intersect(expectedTimestamps) should not be empty + + // Extract and verify user IDs (keys contain dataset name + user_id in binary format) + val responseKeys = outputs.map(r => new String(r.keyBytes)).toSet + val expectedUsers = elements.map(_.user_id).toSet + + // Check that each expected user ID appears somewhere in at least one response key + expectedUsers.foreach { expectedUser => + responseKeys.exists(_.contains(expectedUser)) should be (true) + } + } + + private def buildFlinkJoinSourceJob(groupBy: GroupBy, elements: Seq[JoinTestEvent], api: Api): (ChainedGroupByJob, GroupByServingInfoParsed) = { + val joinSource = groupBy.streamingSource.get.getJoinSource + val parentJoin = joinSource.getJoin + + val leftSourceQuery = parentJoin.getLeft.query + val leftSourceGroupByName = s"left_source_${parentJoin.getMetaData.getName}" + + val sparkExpressionEvalFn = new SparkExpressionEvalFn(Encoders.product[JoinTestEvent], leftSourceQuery, leftSourceGroupByName) + val source = new MockJoinSource(elements, sparkExpressionEvalFn) + + // Schema after applying projection on the join's input topic + val inputSchemaDataTypes = Seq( + ("user_id", ai.chronon.api.StringType), + ("listing_id", ai.chronon.api.StringType), + ("price_discounted", ai.chronon.api.DoubleType), + ("ts", ai.chronon.api.LongType) + ) + + // we don't use the output schema in the Tiling implementation so we pass a dummy one + val groupByServingInfoParsed = + FlinkTestUtils.makeTestGroupByServingInfoParsed( + groupBy, + SparkConversions.fromChrononSchema(inputSchemaDataTypes), + SparkConversions.fromChrononSchema(inputSchemaDataTypes) + ) + + // Extract topic info from the join source + val leftSource = joinSource.getJoin.getLeft + val topicUri = leftSource.topic + val topicInfo = TopicInfo.parse(topicUri) + + val writerFn = new MockAsyncKVStoreWriter(Seq(true), api, groupBy.metaData.name) + val propsWithKafka = Map( + "bootstrap" -> "localhost:9092", + "trigger" -> "always_fire" // Explicitly use same trigger as working FlinkJobEventIntegrationTest + ) + + val flinkJob = new ChainedGroupByJob( + eventSrc = source, + inputSchema = inputSchemaDataTypes, + sinkFn = writerFn, + groupByServingInfoParsed = groupByServingInfoParsed, + parallelism = 2, + props = propsWithKafka, + topicInfo = topicInfo, + api = api, + enableDebug = true + ) + + (flinkJob, groupByServingInfoParsed) + } +} \ No newline at end of file diff --git a/flink/src/test/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunctionTest.scala b/flink/src/test/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunctionTest.scala new file mode 100644 index 0000000000..8bd9d951d6 --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/chaining/JoinEnrichmentAsyncFunctionTest.scala @@ -0,0 +1,227 @@ +package ai.chronon.flink.chaining + +import ai.chronon.api.Constants +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.online.fetcher.Fetcher +import ai.chronon.online.Api +import org.apache.flink.configuration.Configuration +import 
org.apache.flink.api.common.functions.RuntimeContext +import org.apache.flink.metrics.{Counter, Histogram, MetricGroup} +import org.apache.flink.metrics.groups.OperatorMetricGroup +import org.apache.flink.streaming.api.functions.async.ResultFuture +import org.mockito.ArgumentMatchers._ +import org.mockito.Mockito._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import scala.concurrent.Promise + +class JoinEnrichmentAsyncFunctionTest extends AnyFlatSpec with Matchers with MockitoSugar { + + val joinRequestName = "joins/test_team/test_join" + val enableDebug = false + + private def setupFunctionWithMockedMetrics(function: JoinEnrichmentAsyncFunction): Unit = { + // Mock the runtime context and metrics + val mockRuntimeContext = mock[RuntimeContext] + val mockOperatorMetricGroup = mock[OperatorMetricGroup] + val mockSubGroup = mock[MetricGroup] + val mockCounter = mock[Counter] + val mockNotFoundCounter = mock[Counter] + val mockHistogram = mock[Histogram] + + // Mock the metric group chain + when(mockRuntimeContext.getMetricGroup).thenReturn(mockOperatorMetricGroup) + when(mockOperatorMetricGroup.addGroup("chronon")).thenReturn(mockSubGroup) + when(mockSubGroup.addGroup(anyString(), anyString())).thenReturn(mockSubGroup) + when(mockSubGroup.counter("join_fetch.successes")).thenReturn(mockCounter) + when(mockSubGroup.counter("join_fetch.errors")).thenReturn(mockCounter) + when(mockSubGroup.counter("join_fetch.not_found")).thenReturn(mockNotFoundCounter) + when(mockSubGroup.histogram(anyString(), any())).thenReturn(mockHistogram) + + function.setRuntimeContext(mockRuntimeContext) + } + + "JoinEnrichmentAsyncFunction" should "enrich events with join response" in { + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + + when(mockApi.buildFetcher(debug = enableDebug)).thenReturn(mockFetcher) + + // Create successful join response + val joinResponse = Fetcher.Response( + Fetcher.Request(joinRequestName, Map("user_id" -> "123"), Some(2000L)), + scala.util.Success(Map("user_category" -> "premium", "user_score" -> 85.5).asInstanceOf[Map[String, AnyRef]]) + ) + val joinFuture = Promise[Seq[Fetcher.Response]]() + joinFuture.success(Seq(joinResponse)) + when(mockFetcher.fetchJoin(any(), any())).thenReturn(joinFuture.future) + + val function = new JoinEnrichmentAsyncFunction(joinRequestName, "testGB", mockApi, enableDebug) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create test event + val eventFields = Map("user_id" -> "123", "price" -> 99.99, Constants.TimeColumn -> 1000L) + val event = ProjectedEvent(eventFields, 500L) + + // Mock result future + val latch = new CountDownLatch(1) + var result: ProjectedEvent = null + val resultFuture = new ResultFuture[ProjectedEvent] { + override def complete(results: java.util.Collection[ProjectedEvent]): Unit = { + result = results.iterator().next() + latch.countDown() + } + override def completeExceptionally(throwable: Throwable): Unit = { + throwable.printStackTrace() + latch.countDown() + } + } + + // Execute async invocation + function.asyncInvoke(event, resultFuture) + + // Wait for completion + latch.await(5, TimeUnit.SECONDS) should be(true) + + // Verify results + result should not be null + result.startProcessingTimeMillis should be(500L) + + val fields = result.fields + fields("user_id") should be("123") + fields("price") should be(99.99) + 
fields("user_category") should be("premium") + fields("user_score") should be(85.5) + + // Verify the join request was made + verify(mockFetcher).fetchJoin(any(), any()) + } + + it should "handle join timeout gracefully" in { + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + + when(mockApi.buildFetcher(debug = enableDebug)).thenReturn(mockFetcher) + + val function = new JoinEnrichmentAsyncFunction(joinRequestName, "testGB", mockApi, enableDebug) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create test event + val eventFields = Map("user_id" -> "123", "price" -> 99.99, Constants.TimeColumn -> 1000L) + val event = ProjectedEvent(eventFields, 500L) + + // Mock result future + val latch = new CountDownLatch(1) + var results: java.util.Collection[ProjectedEvent] = null + val resultFuture = new ResultFuture[ProjectedEvent] { + override def complete(resultCollection: java.util.Collection[ProjectedEvent]): Unit = { + results = resultCollection + latch.countDown() + } + override def completeExceptionally(throwable: Throwable): Unit = { + latch.countDown() + } + } + + // Execute timeout + function.timeout(event, resultFuture) + + // Wait for completion + latch.await(5, TimeUnit.SECONDS) should be(true) + + // Verify timeout result (event is swallowed - empty collection returned) + results should not be null + results.isEmpty should be(true) + } + + it should "handle join failure gracefully" in { + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + + when(mockApi.buildFetcher(debug = enableDebug)).thenReturn(mockFetcher) + + // Create failed join response + val joinFuture = Promise[Seq[Fetcher.Response]]() + joinFuture.failure(new RuntimeException("Join fetch failed")) + when(mockFetcher.fetchJoin(any(), any())).thenReturn(joinFuture.future) + + val function = new JoinEnrichmentAsyncFunction(joinRequestName, "testGB", mockApi, enableDebug) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create test event + val eventFields = Map("user_id" -> "123", "price" -> 99.99, Constants.TimeColumn -> 1000L) + val event = ProjectedEvent(eventFields, 500L) + + // Mock result future + val latch = new CountDownLatch(1) + var results: java.util.Collection[ProjectedEvent] = null + val resultFuture = new ResultFuture[ProjectedEvent] { + override def complete(resultCollection: java.util.Collection[ProjectedEvent]): Unit = { + results = resultCollection + latch.countDown() + } + override def completeExceptionally(throwable: Throwable): Unit = { + latch.countDown() + } + } + + // Execute async invocation + function.asyncInvoke(event, resultFuture) + + // Wait for completion + latch.await(5, TimeUnit.SECONDS) should be(true) + + // Verify failure result (event is swallowed - empty collection returned) + results should not be null + results.isEmpty should be(true) + } + + it should "handle empty join response" in { + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + + when(mockApi.buildFetcher(debug = enableDebug)).thenReturn(mockFetcher) + + // Create empty join response + val joinFuture = Promise[Seq[Fetcher.Response]]() + joinFuture.success(Seq.empty) + when(mockFetcher.fetchJoin(any(), any())).thenReturn(joinFuture.future) + + val function = new JoinEnrichmentAsyncFunction(joinRequestName, "testGB", mockApi, enableDebug) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create test event + val eventFields = Map("user_id" -> "123", "price" -> 99.99, Constants.TimeColumn -> 
1000L) + val event = ProjectedEvent(eventFields, 500L) + + // Mock result future + val latch = new CountDownLatch(1) + var results: java.util.Collection[ProjectedEvent] = null + val resultFuture = new ResultFuture[ProjectedEvent] { + override def complete(resultCollection: java.util.Collection[ProjectedEvent]): Unit = { + results = resultCollection + latch.countDown() + } + override def completeExceptionally(throwable: Throwable): Unit = { + latch.countDown() + } + } + + // Execute async invocation + function.asyncInvoke(event, resultFuture) + + // Wait for completion + latch.await(5, TimeUnit.SECONDS) should be(true) + + // Verify empty join response result (event is swallowed - empty collection returned) + results should not be null + results.isEmpty should be(true) + } +} \ No newline at end of file diff --git a/flink/src/test/scala/ai/chronon/flink/chaining/JoinSourceQueryFunctionTest.scala b/flink/src/test/scala/ai/chronon/flink/chaining/JoinSourceQueryFunctionTest.scala new file mode 100644 index 0000000000..ebf0893081 --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/chaining/JoinSourceQueryFunctionTest.scala @@ -0,0 +1,230 @@ +package ai.chronon.flink.chaining + +import ai.chronon.api.{Builders, DoubleType, IntType, LongType, StringType} +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.online.{Api, JoinCodec} +import ai.chronon.online.fetcher.Fetcher +import org.apache.flink.api.common.functions.RuntimeContext +import org.apache.flink.configuration.Configuration +import org.apache.flink.metrics.{Counter, Histogram, MetricGroup} +import org.apache.flink.metrics.groups.OperatorMetricGroup +import org.apache.flink.util.Collector +import org.mockito.ArgumentMatchers._ +import org.mockito.Mockito._ +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatestplus.mockito.MockitoSugar + +import scala.collection.mutable.ListBuffer + +class JoinSourceQueryFunctionTest extends AnyFlatSpec with Matchers with MockitoSugar { + + val inputSchema = Seq( + ("user_id", StringType), + ("price", DoubleType), + ("timestamp", LongType) + ) + + private def setupFunctionWithMockedMetrics(function: JoinSourceQueryFunction): Unit = { + // Mock the runtime context and metrics + val mockRuntimeContext = mock[RuntimeContext] + val mockOperatorMetricGroup = mock[OperatorMetricGroup] + val mockSubGroup = mock[MetricGroup] + val mockCounter = mock[Counter] + val mockHistogram = mock[Histogram] + + // Mock the metric group chain + when(mockRuntimeContext.getMetricGroup).thenReturn(mockOperatorMetricGroup) + when(mockOperatorMetricGroup.addGroup("chronon")).thenReturn(mockSubGroup) + when(mockSubGroup.addGroup(anyString(), anyString())).thenReturn(mockSubGroup) + when(mockSubGroup.counter(anyString())).thenReturn(mockCounter) + when(mockSubGroup.histogram(anyString(), any())).thenReturn(mockHistogram) + + function.setRuntimeContext(mockRuntimeContext) + } + + "JoinSourceQueryFunction" should "apply SQL transformations correctly" in { + // Create a join source with SQL query + val parentJoin = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "price" -> "price"), + timeColumn = "timestamp" + ), + table = "test.events", + topic = "kafka://test-topic" + ), + joinParts = Seq(), + metaData = Builders.MetaData(name = "test.parent_join") + ) + + val joinSource = Builders.Source.joinSource( + join = parentJoin, + query = Builders.Query( + selects = Map( + "user_id" -> "user_id", + 
"doubled_price" -> "price * 2", + "price_category" -> "CASE WHEN price > 100 THEN 'expensive' ELSE 'affordable' END" + ), + timeColumn = "timestamp" + ) + ).getJoinSource + + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + val mockMetadataStore = mock[ai.chronon.online.fetcher.MetadataStore] + val mockJoinCodec = mock[JoinCodec] + + when(mockApi.buildFetcher(debug = false)).thenReturn(mockFetcher) + when(mockFetcher.metadataStore).thenReturn(mockMetadataStore) + when(mockMetadataStore.buildJoinCodec(parentJoin, refreshOnFail = false)).thenReturn(mockJoinCodec) + + // Mock join codec schema (enriched fields from join) + val joinValueSchema = ai.chronon.api.StructType("join_enriched", Array( + ai.chronon.api.StructField("user_category", StringType), + ai.chronon.api.StructField("user_score", DoubleType) + )) + when(mockJoinCodec.valueSchema).thenReturn(joinValueSchema) + + val function = new JoinSourceQueryFunction(joinSource, inputSchema, groupByName = "testGB", mockApi, enableDebug = false) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create enriched event (after join processing) + val enrichedFields = Map( + "user_id" -> "123", + "price" -> 150.0, + "timestamp" -> 1000L, + "user_category" -> "premium", // From join + "user_score" -> 85.5 // From join + ) + val enrichedEvent = ProjectedEvent(enrichedFields, 500L) + + // Collect outputs + val outputs = ListBuffer[ProjectedEvent]() + val collector = new Collector[ProjectedEvent] { + override def collect(record: ProjectedEvent): Unit = outputs += record + override def close(): Unit = {} + } + + // Execute transformation + function.flatMap(enrichedEvent, collector) + + // Verify transformation results + outputs should have size 1 + val result = outputs.head + + result.startProcessingTimeMillis should be(500L) + + val resultFields = result.fields + resultFields("user_id") should be("123") + resultFields("doubled_price") should be(300.0) // 150 * 2 + resultFields("price_category") should be("expensive") // price > 100 + } + + + it should "handle query errors gracefully" in { + val parentJoin = Builders.Join( + left = Builders.Source.events( + query = Builders.Query(), + table = "test.events", + topic = "kafka://test-topic" + ), + joinParts = Seq(), + metaData = Builders.MetaData(name = "test.parent_join") + ) + + val joinSource = Builders.Source.joinSource( + join = parentJoin, + query = Builders.Query( + selects = Map("user_id" -> "user_id", "price_doubled" -> "price * 2"), + timeColumn = "timestamp" + ) + ).getJoinSource + + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + val mockMetadataStore = mock[ai.chronon.online.fetcher.MetadataStore] + val mockJoinCodec = mock[JoinCodec] + + when(mockApi.buildFetcher()).thenReturn(mockFetcher) + when(mockFetcher.metadataStore).thenReturn(mockMetadataStore) + when(mockMetadataStore.buildJoinCodec(parentJoin, refreshOnFail = false)).thenReturn(mockJoinCodec) + + val joinValueSchema = ai.chronon.api.StructType("join_enriched", Array( + ai.chronon.api.StructField("user_category", StringType) + )) + when(mockJoinCodec.valueSchema).thenReturn(joinValueSchema) + + val function = new JoinSourceQueryFunction(joinSource, inputSchema, groupByName = "testGB", mockApi, enableDebug = false) + setupFunctionWithMockedMetrics(function) + function.open(new Configuration()) + + // Create enriched event with invalid data that will cause an exception when cast to wrong type + val enrichedFields = Map( + "user_id" -> "123", + "price" -> "invalid_number", // 
Throws a ClassCastException when processed as Double + "timestamp" -> 1000L, + "user_category" -> "premium" + ) + val enrichedEvent = ProjectedEvent(enrichedFields, 500L) + + // Collect outputs + val outputs = ListBuffer[ProjectedEvent]() + val collector = new Collector[ProjectedEvent] { + override def collect(record: ProjectedEvent): Unit = outputs += record + override def close(): Unit = {} + } + + // Execute (should handle error gracefully and swallow event) + function.flatMap(enrichedEvent, collector) + + // Should have swallowed the event on error - no outputs + outputs.isEmpty shouldBe true + } + + it should "build join schema correctly" in { + // Test that the schema building combines left source + join codec schemas + val parentJoin = Builders.Join( + left = Builders.Source.events( + query = Builders.Query(), + table = "test.events", + topic = "kafka://test-topic" + ), + joinParts = Seq(), + metaData = Builders.MetaData(name = "test.parent_join") + ) + + val joinSource = Builders.Source.joinSource( + join = parentJoin, + query = Builders.Query( + selects = Map("user_id" -> "user_id", "enriched_field" -> "enriched_field"), + timeColumn = "timestamp" + ) + ).getJoinSource + + val mockApi = mock[Api] + val mockFetcher = mock[Fetcher] + val mockMetadataStore = mock[ai.chronon.online.fetcher.MetadataStore] + val mockJoinCodec = mock[JoinCodec] + + when(mockApi.buildFetcher(debug = false)).thenReturn(mockFetcher) + when(mockFetcher.metadataStore).thenReturn(mockMetadataStore) + when(mockMetadataStore.buildJoinCodec(parentJoin, refreshOnFail = false)).thenReturn(mockJoinCodec) + + // Mock join codec value schema (fields added by join) + val joinValueSchema = ai.chronon.api.StructType("join_enriched", Array( + ai.chronon.api.StructField("enriched_field", StringType), + ai.chronon.api.StructField("another_enriched", IntType) + )) + when(mockJoinCodec.valueSchema).thenReturn(joinValueSchema) + + val function = new JoinSourceQueryFunction(joinSource, inputSchema, groupByName = "testGB", mockApi, enableDebug = false) + setupFunctionWithMockedMetrics(function) + + // Opening the function should trigger schema building without errors + function.open(new Configuration()) + + verify(mockJoinCodec, org.mockito.Mockito.atLeast(1)).valueSchema // Should have accessed the join schema + } +} \ No newline at end of file diff --git a/flink/src/test/scala/ai/chronon/flink/chaining/JoinTestUtils.scala b/flink/src/test/scala/ai/chronon/flink/chaining/JoinTestUtils.scala new file mode 100644 index 0000000000..ab7af5f7b1 --- /dev/null +++ b/flink/src/test/scala/ai/chronon/flink/chaining/JoinTestUtils.scala @@ -0,0 +1,223 @@ +package ai.chronon.flink.chaining + +import ai.chronon.api +import ai.chronon.api.{Accuracy, Builders, DoubleType, GroupBy, Join, Operation, StringType, StructField, StructType, TimeUnit, Window} +import ai.chronon.flink.SparkExpressionEvalFn +import ai.chronon.flink.deser.ProjectedEvent +import ai.chronon.flink.source.FlinkSource +import ai.chronon.online.KVStore.{GetRequest, GetResponse, PutRequest} +import ai.chronon.online.{Api, ExternalSourceRegistry, GroupByServingInfoParsed, JoinCodec, KVStore, LoggableResponse} +import ai.chronon.online.fetcher.{FetchContext, Fetcher, MetadataStore} +import ai.chronon.online.serde.SerDe +import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy} +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment + +import 
scala.collection.JavaConverters._ +import java.time.Duration +import scala.concurrent.Future +import scala.util.Success + +case class JoinTestEvent(user_id: String, listing_id: String, price: Double, created: Long) + +class MockJoinSource(mockEvents: Seq[JoinTestEvent], sparkExprEvalFn: SparkExpressionEvalFn[JoinTestEvent]) + extends FlinkSource[ProjectedEvent] { + + implicit val parallelism: Int = 1 + + override def getDataStream(topic: String, groupName: String)( + env: StreamExecutionEnvironment, + parallelism: Int): SingleOutputStreamOperator[ProjectedEvent] = { + + val watermarkStrategy = WatermarkStrategy + .forBoundedOutOfOrderness[JoinTestEvent](Duration.ofSeconds(5)) + .withTimestampAssigner(new SerializableTimestampAssigner[JoinTestEvent] { + override def extractTimestamp(event: JoinTestEvent, previousElementTimestamp: Long): Long = { + event.created + } + }) + + env + .fromCollection(mockEvents.asJava) + .assignTimestampsAndWatermarks(watermarkStrategy) + .flatMap(sparkExprEvalFn) + .map(e => ProjectedEvent(e, System.currentTimeMillis())) + } +} + +// Serializable implementations of Api, KVStore, Fetcher, and MetadataStore for testing purposes +// We primarily want to override the Fetcher and MetadataStore (for the join codec) +class TestApi extends Api(Map.empty) with Serializable { + + // Implement required abstract methods + override def externalRegistry: ExternalSourceRegistry = null + override def genMetricsKvStore(tableBaseName: String): KVStore = new TestKVStore() + override def logResponse(resp: LoggableResponse): Unit = {} + override def streamDecoder(groupByServingInfoParsed: GroupByServingInfoParsed): SerDe = null + + override def genKvStore: KVStore = new TestKVStore() + + override def buildFetcher(debug: Boolean = false, callerName: String = null, disableErrorThrows: Boolean = false): Fetcher = { + new TestFetcher(genKvStore) + } +} + +class TestKVStore extends KVStore with Serializable { + override def get(request: GetRequest): Future[GetResponse] = { + Future.successful(GetResponse(request, Success(Seq.empty))) + } + + override def multiGet(requests: Seq[GetRequest]): Future[Seq[GetResponse]] = { + Future.successful(requests.map(req => GetResponse(req, Success(Seq.empty)))) + } + + override def put(putRequest: PutRequest): Future[Boolean] = Future.successful(true) + override def multiPut(putRequests: Seq[KVStore.PutRequest]): Future[Seq[Boolean]] = Future.successful(putRequests.map(_ => true)) + override def create(dataset: String): Unit = {} + override def bulkPut(sourceOfflineTable: String, destinationOnlineDataSet: String, partition: String): Unit = {} +} + +class TestFetcher(kvStore: KVStore) extends Fetcher( + kvStore = kvStore, + metaDataSet = "test_metadata", + timeoutMillis = 5000L, + logFunc = null, + debug = false, + externalSourceRegistry = null, + callerName = "test", + flagStore = null, + disableErrorThrows = true, + executionContextOverride = null +) with Serializable { + + override val metadataStore: MetadataStore = new TestMetadataStore() + + override def fetchJoin(requests: Seq[Fetcher.Request], joinConf: Option[Join] = None): Future[Seq[Fetcher.Response]] = { + println(s"TestFetcher.fetchJoin called with ${requests.size} requests") + requests.foreach { request => + println(s"Join request: name=${request.name}, keys=${request.keys}, atMillis=${request.atMillis}") + } + + val responses = requests.map { request => + val listingId = request.keys.getOrElse("listing_id", "unknown").toString + + // Return enrichment values based on listing_id - parent 
join returns price_discounted_last for each listing + val enrichmentValues = Map[String, AnyRef]( + "price_discounted_last" -> (listingId match { + case "listing1" => 50.0.asInstanceOf[AnyRef] // Last price for listing1 + case "listing2" => 150.0.asInstanceOf[AnyRef] // Last price for listing2 + case "listing3" => 25.0.asInstanceOf[AnyRef] // Last price for listing3 + case _ => 50.0.asInstanceOf[AnyRef] + }) + ) + + println(s"Enrichment response for listing $listingId: $enrichmentValues") + Fetcher.Response(request, Success(enrichmentValues)) + } + + println(s"TestFetcher returning ${responses.size} responses") + Future.successful(responses) + } +} + +class TestMetadataStore extends MetadataStore(FetchContext(new TestKVStore(), null)) with Serializable { + override def buildJoinCodec(join: Join, refreshOnFail: Boolean = false): JoinCodec = TestJoinCodec.build(join) +} + +object TestJoinCodec { + import ai.chronon.api.Extensions.JoinOps + import ai.chronon.online.serde.AvroCodec + import ai.chronon.online.serde.AvroConversions + + def build(parentJoin: Join): JoinCodec = { + // Key schema: listing_id (from the GroupBy that the parent join queries) + val keyFields = Array(StructField("listing_id", StringType)) + val keySchema = StructType("parent_join_key", keyFields) + + // Value schema: only price_last (the feature from GroupBy 1) + val valueFields = Array( + StructField("price_discounted_last", DoubleType) // Feature from GroupBy 1 + ) + val valueSchema = StructType("parent_join_value", valueFields) + + val conf = new JoinOps(parentJoin) + + // Create proper Avro codecs using AvroConversions + val keyAvroSchema = AvroConversions.fromChrononSchema(keySchema) + val valueAvroSchema = AvroConversions.fromChrononSchema(valueSchema) + + val keyCodec = AvroCodec.of(keyAvroSchema.toString()) + val valueCodec = AvroCodec.of(valueAvroSchema.toString()) + + new JoinCodec(conf, keySchema, valueSchema, keyCodec, valueCodec, Array.empty, false) + } +} + +object JoinTestUtils { + // Build our chained GroupBy. 
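+ // The terminal GroupBy consumes the parent join's output as a JoinSource, so each streaming event is
+ // enriched with the join's feature (price_discounted_last) before the final aggregation runs.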
+ // We have: + // Source: listing_views (user_id, listing_id, price, created) + // GroupBy: listing_features (listing_id) -> last(price) + // Join: parentJoin JOIN listing_features ON listing_id + // GroupBy: user_price_features (user_id) -> last_k(price, k=10, window=1d) + def buildJoinSourceTerminalGroupBy(): api.GroupBy = { + val parentJoin = Builders.Join( + left = Builders.Source.events( + query = Builders.Query( + selects = Map("user_id" -> "user_id", "listing_id" -> "listing_id", "price_discounted" -> "price * 0.95"), + timeColumn = "created" + ), + table = "test.listing_views", + topic = "kafka://test-topic/serde=custom/provider_class=ai.chronon.flink.deser.MockCustomSchemaProvider/schema_name=item_event" + ), + joinParts = Seq( + // Mock join part representing listing enrichment + Builders.JoinPart( + groupBy = Builders.GroupBy( + sources = Seq(Builders.Source.events( + query = Builders.Query(), + table = "test.listings" + )), + keyColumns = Seq("listing_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.LAST, + inputColumn = "price_discounted" + ) + ), + metaData = Builders.MetaData(name = "test.listing_features") + ) + ) + ), + metaData = Builders.MetaData(name = "test.upstream_join") + ) + + // Create the chaining GroupBy that uses the upstream join as source + val chainGroupBy = Builders.GroupBy( + sources = Seq(Builders.Source.joinSource( + join = parentJoin, + query = + Builders.Query( + selects = Map( + "user_id" -> "user_id", + "listing_id" -> "listing_id", + "final_price" -> "price_discounted_last", + ), + ) + )), + keyColumns = Seq("user_id"), + aggregations = Seq( + Builders.Aggregation( + operation = Operation.LAST_K, + inputColumn = "final_price", // Use simple price field + argMap = Map("k" -> "10"), + windows = Seq(new Window(1, TimeUnit.DAYS)) + ) + ), + metaData = Builders.MetaData(name = "test.user_price_features"), + accuracy = Accuracy.TEMPORAL + ) + + chainGroupBy + } +} diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEntityIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEntityIntegrationTest.scala index 82f90bfef1..d71d07c6ec 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEntityIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEntityIntegrationTest.scala @@ -1,9 +1,10 @@ package ai.chronon.flink.test import ai.chronon.api.Constants.{ReversalColumn, TimeColumn} +import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.api.GroupBy -import ai.chronon.flink.{FlinkJob, SparkExpressionEval, SparkExpressionEvalFn} +import ai.chronon.flink.{FlinkGroupByStreamingJob, SparkExpressionEval, SparkExpressionEvalFn} import ai.chronon.online.serde.SparkConversions import ai.chronon.online.{Api, GroupByServingInfoParsed, TopicInfo} import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration @@ -76,13 +77,14 @@ class FlinkJobEntityIntegrationTest extends AnyFlatSpec with BeforeAndAfter { } private def buildFlinkJob(groupBy: GroupBy, - elements: Seq[E2ETestMutationEvent]): (FlinkJob, GroupByServingInfoParsed) = { - val sparkExpressionEvalFn = new SparkExpressionEvalFn(Encoders.product[E2ETestMutationEvent], groupBy) + elements: Seq[E2ETestMutationEvent]): (FlinkGroupByStreamingJob, GroupByServingInfoParsed) = { + val query = SparkExpressionEval.queryFromGroupBy(groupBy) + val sparkExpressionEvalFn = new SparkExpressionEvalFn(Encoders.product[E2ETestMutationEvent], query, groupBy.metaData.name, 
groupBy.dataModel) val source = new WatermarkedE2ETestMutationEventSource(elements, sparkExpressionEvalFn) // Prepare the Flink Job val encoder = Encoders.product[E2ETestMutationEvent] - val outputSchema = new SparkExpressionEval(encoder, groupBy).getOutputSchema + val outputSchema = new SparkExpressionEval(encoder, query, groupBy.getMetaData.getName, groupBy.dataModel).getOutputSchema val outputSchemaDataTypes = outputSchema.fields.map { field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)) } @@ -92,7 +94,7 @@ class FlinkJobEntityIntegrationTest extends AnyFlatSpec with BeforeAndAfter { val mockApi = mock[Api](withSettings().serializable()) val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, groupBy.metaData.name) val topicInfo = TopicInfo.parse("kafka://test-topic") - (new FlinkJob(source, + (new FlinkGroupByStreamingJob(source, outputSchemaDataTypes, writerFn, groupByServingInfoParsed, diff --git a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEventIntegrationTest.scala b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEventIntegrationTest.scala index 4a5a1d6042..ce8e48d9e2 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEventIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/FlinkJobEventIntegrationTest.scala @@ -1,8 +1,9 @@ package ai.chronon.flink.test +import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.{GroupBy, TilingUtils} import ai.chronon.api.ScalaJavaConversions._ -import ai.chronon.flink.{FlinkJob, SparkExpressionEval, SparkExpressionEvalFn} +import ai.chronon.flink.{FlinkGroupByStreamingJob, SparkExpressionEval, SparkExpressionEvalFn} import ai.chronon.flink.types.TimestampedIR import ai.chronon.flink.types.TimestampedTile import ai.chronon.flink.types.WriteResponse @@ -153,13 +154,14 @@ class FlinkJobEventIntegrationTest extends AnyFlatSpec with BeforeAndAfter { expectedFinalIRsPerKey shouldBe finalIRsPerKey } - private def buildFlinkJob(groupBy: GroupBy, elements: Seq[E2ETestEvent]): (FlinkJob, GroupByServingInfoParsed) = { - val sparkExpressionEvalFn = new SparkExpressionEvalFn(Encoders.product[E2ETestEvent], groupBy) + private def buildFlinkJob(groupBy: GroupBy, elements: Seq[E2ETestEvent]): (FlinkGroupByStreamingJob, GroupByServingInfoParsed) = { + val query = SparkExpressionEval.queryFromGroupBy(groupBy) + val sparkExpressionEvalFn = new SparkExpressionEvalFn(Encoders.product[E2ETestEvent], query, groupBy.metaData.name, groupBy.dataModel) val source = new WatermarkedE2EEventSource(elements, sparkExpressionEvalFn) // Prepare the Flink Job val encoder = Encoders.product[E2ETestEvent] - val outputSchema = new SparkExpressionEval(encoder, groupBy).getOutputSchema + val outputSchema = new SparkExpressionEval(encoder, query, groupBy.getMetaData.getName, groupBy.dataModel).getOutputSchema val outputSchemaDataTypes = outputSchema.fields.map { field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)) } @@ -169,7 +171,7 @@ class FlinkJobEventIntegrationTest extends AnyFlatSpec with BeforeAndAfter { val mockApi = mock[Api](withSettings().serializable()) val writerFn = new MockAsyncKVStoreWriter(Seq(true), mockApi, groupBy.metaData.name) val topicInfo = TopicInfo.parse("kafka://test-topic") - (new FlinkJob(source, + (new FlinkGroupByStreamingJob(source, outputSchemaDataTypes, writerFn, groupByServingInfoParsed, diff --git a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala 
b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala index 21b6f6dbdd..b7d80f7515 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/SparkExpressionEvalFnTest.scala @@ -1,8 +1,9 @@ package ai.chronon.flink.test +import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.ScalaJavaConversions.IteratorOps import ai.chronon.api.ScalaJavaConversions.JListOps -import ai.chronon.flink.SparkExpressionEvalFn +import ai.chronon.flink.{SparkExpressionEval, SparkExpressionEvalFn} import org.apache.flink.streaming.api.datastream.DataStream import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment import org.apache.spark.sql.Encoders @@ -19,11 +20,13 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { ) val groupBy = FlinkTestUtils.makeGroupBy(Seq("id")) + val query = SparkExpressionEval.queryFromGroupBy(groupBy) val encoder = Encoders.product[E2ETestEvent] val sparkExprEval = new SparkExpressionEvalFn[E2ETestEvent]( encoder, - groupBy + query, + groupBy.metaData.name ) val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -47,12 +50,14 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { ) val groupBy = FlinkTestUtils.makeGroupBy(Seq("id"), filters = null) + val query = SparkExpressionEval.queryFromGroupBy(groupBy) val encoder = Encoders.product[E2ETestEvent] val sparkExprEval = new SparkExpressionEvalFn[E2ETestEvent]( encoder, - groupBy + query, + groupBy.metaData.name ) val env = StreamExecutionEnvironment.getExecutionEnvironment @@ -73,11 +78,14 @@ class SparkExpressionEvalFnTest extends AnyFlatSpec { ) val groupBy = FlinkTestUtils.makeEntityGroupBy(Seq("id")) + val query = SparkExpressionEval.queryFromGroupBy(groupBy) val encoder = Encoders.product[E2ETestMutationEvent] val sparkExprEval = new SparkExpressionEvalFn[E2ETestMutationEvent]( encoder, - groupBy + query, + groupBy.metaData.name, + groupBy.dataModel ) val env = StreamExecutionEnvironment.getExecutionEnvironment diff --git a/flink/src/test/scala/ai/chronon/flink/test/deser/CatalystUtilComplexAvroTest.scala b/flink/src/test/scala/ai/chronon/flink/test/deser/CatalystUtilComplexAvroTest.scala index 5e56d978a6..be8d7c63a9 100644 --- a/flink/src/test/scala/ai/chronon/flink/test/deser/CatalystUtilComplexAvroTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/test/deser/CatalystUtilComplexAvroTest.scala @@ -1,7 +1,7 @@ package ai.chronon.flink.test.deser import ai.chronon.api.{Accuracy, Builders, GroupBy} -import ai.chronon.flink.deser.{ProjectedEvent, SourceProjectionDeserializationSchema} +import ai.chronon.flink.deser.{DeserializationSchemaBuilder, ProjectedEvent, SourceProjectionDeserializationSchema} import ai.chronon.online.serde.{AvroCodec, AvroSerDe} import org.apache.avro.Schema import org.apache.avro.generic.{GenericData, GenericRecord} @@ -65,7 +65,7 @@ class CatalystUtilComplexAvroTest extends AnyFlatSpec { val avroSchema = AvroCodec.of(testSchema.toString).schema val avroSerDe = new AvroSerDe(avroSchema) - val deserSchema = new SourceProjectionDeserializationSchema(avroSerDe, testGroupBy) + val deserSchema = DeserializationSchemaBuilder.buildSourceProjectionDeserSchema(avroSerDe, testGroupBy) deserSchema.open(new DummyInitializationContext) val resultList = new util.ArrayList[ProjectedEvent]() val listCollector = new ListCollector(resultList) diff --git a/flink/src/test/scala/ai/chronon/flink/validation/ValidationFlinkJobIntegrationTest.scala 
b/flink/src/test/scala/ai/chronon/flink/validation/ValidationFlinkJobIntegrationTest.scala index 608c1588d6..81e2cd5d5b 100644 --- a/flink/src/test/scala/ai/chronon/flink/validation/ValidationFlinkJobIntegrationTest.scala +++ b/flink/src/test/scala/ai/chronon/flink/validation/ValidationFlinkJobIntegrationTest.scala @@ -1,5 +1,6 @@ package ai.chronon.flink.validation +import ai.chronon.api.Extensions.GroupByOps import ai.chronon.api.ScalaJavaConversions._ import ai.chronon.flink.source.FlinkSource import ai.chronon.flink.test.{CollectSink, FlinkTestUtils} @@ -74,7 +75,8 @@ class ValidationFlinkJobIntegrationTest extends AnyFlatSpec with BeforeAndAfter StructField("created", LongType)) val encoder = Encoders.row(StructType(fields)) - val outputSchema = new SparkExpressionEval(encoder, groupBy).getOutputSchema + val query = SparkExpressionEval.queryFromGroupBy(groupBy) + val outputSchema = new SparkExpressionEval(encoder, query, groupBy.getMetaData.getName, groupBy.dataModel).getOutputSchema val groupByServingInfoParsed = FlinkTestUtils.makeTestGroupByServingInfoParsed(groupBy, encoder.schema, outputSchema) diff --git a/online/src/main/scala/ai/chronon/online/Api.scala b/online/src/main/scala/ai/chronon/online/Api.scala index 28fd48b721..42b032bf58 100644 --- a/online/src/main/scala/ai/chronon/online/Api.scala +++ b/online/src/main/scala/ai/chronon/online/Api.scala @@ -255,9 +255,7 @@ abstract class Api(userConf: Map[String, String]) extends Serializable { // not sure if thread safe - TODO: double check // helper functions - final def buildFetcher(debug: Boolean = false, - callerName: String = null, - disableErrorThrows: Boolean = false): Fetcher = + def buildFetcher(debug: Boolean = false, callerName: String = null, disableErrorThrows: Boolean = false): Fetcher = new Fetcher( genKvStore, Constants.MetadataDataset,