3 changes: 1 addition & 2 deletions flink/src/main/scala/ai/chronon/flink/AvroCodecFn.scala
@@ -127,9 +127,8 @@ case class AvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed)
*
* @param groupByServingInfoParsed The GroupBy we are working with
* @param tilingWindowSizeMs The size of the tiling window in milliseconds
* @tparam T The input data type
*/
case class TiledAvroCodecFn[T](groupByServingInfoParsed: GroupByServingInfoParsed, tilingWindowSizeMs: Long)
case class TiledAvroCodecFn(groupByServingInfoParsed: GroupByServingInfoParsed, tilingWindowSizeMs: Long)
extends BaseAvroCodecFn[TimestampedTile, AvroCodecOutput] {
override def open(configuration: Configuration): Unit = {
super.open(configuration)
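For reference, a minimal sketch of how the now non-generic codec fn is wired into a tiled stream; the tiledStream value and the 5-minute window size are placeholders, not taken from this diff:

// TiledAvroCodecFn no longer needs a type parameter: its input is always
// DataStream[TimestampedTile], regardless of the upstream event type.
val tiledStream: DataStream[TimestampedTile] = ??? // produced by the tiling operator
val avroOutput: DataStream[AvroCodecOutput] =
  tiledStream
    .flatMap(TiledAvroCodecFn(groupByServingInfoParsed, tilingWindowSizeMs = 5 * 60 * 1000L))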
213 changes: 72 additions & 141 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -8,19 +8,19 @@ import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.Extensions.SourceOps
import ai.chronon.api.ScalaJavaConversions._
import ai.chronon.flink.FlinkJob.watermarkStrategy
import ai.chronon.flink.SchemaRegistrySchemaProvider.RegistryHostKey
import ai.chronon.flink.SourceIdentitySchemaRegistrySchemaProvider.RegistryHostKey
import ai.chronon.flink.types.AvroCodecOutput
import ai.chronon.flink.types.TimestampedTile
import ai.chronon.flink.types.WriteResponse
import ai.chronon.flink.validation.ValidationFlinkJob
import ai.chronon.flink.window.AlwaysFireOnElementTrigger
import ai.chronon.flink.window.FlinkRowAggProcessFunction
import ai.chronon.flink.window.FlinkRowAggregationFunction
import ai.chronon.flink.window.KeySelectorBuilder
import ai.chronon.flink.window.{
AlwaysFireOnElementTrigger,
FlinkRowAggProcessFunction,
FlinkRowAggregationFunction,
KeySelectorBuilder
}
import ai.chronon.online.Api
import ai.chronon.online.FlagStoreConstants
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.SparkConversions
import ai.chronon.online.TopicInfo
import ai.chronon.online.fetcher.{FetchContext, MetadataStore}
import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner
@@ -39,7 +39,6 @@ import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner
import org.apache.flink.streaming.api.windowing.time.Time
import org.apache.flink.streaming.api.windowing.windows.TimeWindow
import org.apache.flink.util.OutputTag
import org.apache.spark.sql.Encoder
import org.rogach.scallop.ScallopConf
import org.rogach.scallop.ScallopOption
import org.rogach.scallop.Serialization
@@ -50,31 +49,24 @@ import scala.concurrent.duration.DurationInt
import scala.concurrent.duration.FiniteDuration
import scala.collection.Seq

/** Flink job that processes a single streaming GroupBy and writes out the results to the KV store.
*
* There are two versions of the job, tiled and untiled. The untiled version writes out raw events while the tiled
* version writes out pre-aggregates. See the `runGroupByJob` and `runTiledGroupByJob` methods for more details.
/** Flink job that processes a single streaming GroupBy and writes out the results (in the form of pre-aggregated tiles) to the KV store.
*
* @param eventSrc - Provider of a Flink Datastream[T] for the given topic and groupBy
* @param eventSrc - Provider of a Flink Datastream[ Map[String, Any] ] for the given topic and groupBy. The Map contains
* projected columns from the source data based on projections and filters in the GroupBy.
* @param sinkFn - Async Flink writer function to help us write to the KV store
* @param groupByServingInfoParsed - The GroupBy we are working with
* @param encoder - Spark Encoder for the input data type
* @param parallelism - Parallelism to use for the Flink job
* @tparam T - The input data type
*/
class FlinkJob[T](eventSrc: FlinkSource[T],
sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse],
groupByServingInfoParsed: GroupByServingInfoParsed,
encoder: Encoder[T],
parallelism: Int) {
class FlinkJob(eventSrc: FlinkSource[Map[String, Any]],
inputSchema: Seq[(String, DataType)],
sinkFn: RichAsyncFunction[AvroCodecOutput, WriteResponse],
groupByServingInfoParsed: GroupByServingInfoParsed,
parallelism: Int) {
private[this] val logger = LoggerFactory.getLogger(getClass)

val groupByName: String = groupByServingInfoParsed.groupBy.getMetaData.getName
logger.info(f"Creating Flink job. groupByName=${groupByName}")

protected val exprEval: SparkExpressionEvalFn[T] =
new SparkExpressionEvalFn[T](encoder, groupByServingInfoParsed.groupBy)

if (groupByServingInfoParsed.groupBy.streamingSource.isEmpty) {
throw new IllegalArgumentException(
s"Invalid groupBy: $groupByName. No streaming source"
@@ -84,57 +76,6 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
// The source of our Flink application is a topic
val topic: String = groupByServingInfoParsed.groupBy.streamingSource.get.topic
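A minimal construction sketch for the reworked job signature, assuming a hypothetical in-memory source that already emits projected rows; the source class, schema, and field names below are illustrative only (the real Kafka wiring is in buildFlinkJob further down):

import ai.chronon.api.{DataType, LongType, StringType} // import paths assumed

val inputSchema: Seq[(String, DataType)] =
  Seq("user_id" -> StringType, "listing_id" -> StringType, "ts" -> LongType)

val job = new FlinkJob(
  eventSrc = new InMemoryProjectedSource(projectedRows), // hypothetical FlinkSource[Map[String, Any]]
  inputSchema = inputSchema,
  sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name),
  groupByServingInfoParsed = servingInfo,
  parallelism = 1
)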

/** The "untiled" version of the Flink app.
*
* At a high level, the operators are structured as follows:
* source -> Spark expression eval -> Avro conversion -> KV store writer
* source - Reads objects of type T (specific case class, Thrift / Proto) from a topic
* Spark expression eval - Evaluates the Spark SQL expression in the GroupBy and projects and filters the input data
* Avro conversion - Converts the Spark expr eval output to a form that can be written out to the KV store
* (PutRequest object)
* KV store writer - Writes the PutRequest objects to the KV store using the AsyncDataStream API
*
* In this untiled version, there are no shuffles and thus this ends up being a single node in the Flink DAG
* (with the above 4 operators and parallelism as injected by the user).
*/
def runGroupByJob(env: StreamExecutionEnvironment): DataStream[WriteResponse] = {

logger.info(
f"Running Flink job for groupByName=${groupByName}, Topic=${topic}. " +
"Tiling is disabled.")

// we expect parallelism on the source stream to be set by the source provider
val sourceStream: DataStream[T] =
eventSrc
.getDataStream(topic, groupByName)(env, parallelism)
.uid(s"source-$groupByName")
.name(s"Source for $groupByName")

val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream
.flatMap(exprEval)
.uid(s"spark-expr-eval-flatmap-$groupByName")
.name(s"Spark expression eval for $groupByName")
.setParallelism(sourceStream.getParallelism) // Use same parallelism as previous operator

val sparkExprEvalDSWithWatermarks: DataStream[Map[String, Any]] = sparkExprEvalDS
.assignTimestampsAndWatermarks(watermarkStrategy)
.uid(s"spark-expr-eval-timestamps-$groupByName")
.name(s"Spark expression eval with timestamps for $groupByName")
.setParallelism(sourceStream.getParallelism)

val putRecordDS: DataStream[AvroCodecOutput] = sparkExprEvalDSWithWatermarks
.flatMap(AvroCodecFn[T](groupByServingInfoParsed))
.uid(s"avro-conversion-$groupByName")
.name(s"Avro conversion for $groupByName")
.setParallelism(sourceStream.getParallelism)

AsyncKVStoreWriter.withUnorderedWaits(
putRecordDS,
sinkFn,
groupByName
)
}

/** The "tiled" version of the Flink app.
*
* The operators are structured as follows:
@@ -158,28 +99,17 @@
ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfoParsed.groupBy)

// we expect parallelism on the source stream to be set by the source provider
val sourceStream: DataStream[T] =
val sourceSparkProjectedStream: DataStream[Map[String, Any]] =
eventSrc
.getDataStream(topic, groupByName)(env, parallelism)
.uid(s"source-$groupByName")
.name(s"Source for $groupByName")

val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream
.flatMap(exprEval)
.uid(s"spark-expr-eval-flatmap-$groupByName")
.name(s"Spark expression eval for $groupByName")
.setParallelism(sourceStream.getParallelism) // Use same parallelism as previous operator

val sparkExprEvalDSAndWatermarks: DataStream[Map[String, Any]] = sparkExprEvalDS
val sparkExprEvalDSAndWatermarks: DataStream[Map[String, Any]] = sourceSparkProjectedStream
.assignTimestampsAndWatermarks(watermarkStrategy)
.uid(s"spark-expr-eval-timestamps-$groupByName")
.name(s"Spark expression eval with timestamps for $groupByName")
.setParallelism(sourceStream.getParallelism)

val inputSchema: Seq[(String, DataType)] =
exprEval.getOutputSchema.fields
.map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)))
.toSeq
.setParallelism(sourceSparkProjectedStream.getParallelism)
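The watermarkStrategy used above comes from the FlinkJob companion object (imported at the top of the file) and isn't shown in this diff; a representative shape, with the out-of-orderness bound, idleness timeout, and time column all assumptions rather than the actual values, is:

import java.time.Duration
import org.apache.flink.api.common.eventtime.{SerializableTimestampAssigner, WatermarkStrategy}

val watermarkStrategy: WatermarkStrategy[Map[String, Any]] =
  WatermarkStrategy
    .forBoundedOutOfOrderness[Map[String, Any]](Duration.ofMinutes(5))
    .withIdleness(Duration.ofSeconds(30))
    .withTimestampAssigner(new SerializableTimestampAssigner[Map[String, Any]] {
      // The projected row is expected to carry the GroupBy time column ("ts") in epoch millis.
      override def extractTimestamp(element: Map[String, Any], recordTimestamp: Long): Long =
        element.get("ts").map(_.asInstanceOf[Long]).getOrElse(recordTimestamp)
    })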

val window = TumblingEventTimeWindows
.of(Time.milliseconds(tilingWindowSizeInMillis))
@@ -216,21 +146,21 @@
)
.uid(s"tiling-01-$groupByName")
.name(s"Tiling for $groupByName")
.setParallelism(sourceStream.getParallelism)
.setParallelism(sourceSparkProjectedStream.getParallelism)

// Track late events
tilingDS
.getSideOutput(tilingLateEventsTag)
.flatMap(new LateEventCounter(groupByName))
.uid(s"tiling-side-output-01-$groupByName")
.name(s"Tiling Side Output Late Data for $groupByName")
.setParallelism(sourceStream.getParallelism)
.setParallelism(sourceSparkProjectedStream.getParallelism)

val putRecordDS: DataStream[AvroCodecOutput] = tilingDS
.flatMap(new TiledAvroCodecFn[T](groupByServingInfoParsed, tilingWindowSizeInMillis))
.flatMap(TiledAvroCodecFn(groupByServingInfoParsed, tilingWindowSizeInMillis))
.uid(s"avro-conversion-01-$groupByName")
.name(s"Avro conversion for $groupByName")
.setParallelism(sourceStream.getParallelism)
.setParallelism(sourceSparkProjectedStream.getParallelism)

AsyncKVStoreWriter.withUnorderedWaits(
putRecordDS,
@@ -309,7 +239,6 @@ object FlinkJob {
val groupByName = jobArgs.groupbyName()
val onlineClassName = jobArgs.onlineClass()
val props = jobArgs.apiProps.map(identity)
val useMockedSource = jobArgs.mockSource()
val kafkaBootstrap = jobArgs.kafkaBootstrap.toOption
val validateMode = jobArgs.validate()
val validateRows = jobArgs.validateRows()
@@ -328,48 +257,16 @@
}
}

val maybeServingInfo = metadataStore.getGroupByServingInfo(groupByName)
val flinkJob =
if (useMockedSource) {
// We will yank this conditional block when we wire up our real sources etc.
TestFlinkJob.buildTestFlinkJob(api)
} else {
val maybeServingInfo = metadataStore.getGroupByServingInfo(groupByName)
maybeServingInfo
.map { servingInfo =>
val topicUri = servingInfo.groupBy.streamingSource.get.topic
val topicInfo = TopicInfo.parse(topicUri)

val schemaProvider =
topicInfo.params.get(RegistryHostKey) match {
case Some(_) => new SchemaRegistrySchemaProvider(topicInfo.params)
case None =>
throw new IllegalArgumentException(
s"We only support schema registry based schema lookups. Missing $RegistryHostKey in topic config")
}

val (encoder, deserializationSchema) = schemaProvider.buildEncoderAndDeserSchema(topicInfo)
val source =
topicInfo.messageBus match {
case "kafka" =>
new KafkaFlinkSource(kafkaBootstrap, deserializationSchema, topicInfo)
case _ =>
throw new IllegalArgumentException(s"Unsupported message bus: ${topicInfo.messageBus}")
}

new FlinkJob(
eventSrc = source,
sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name),
groupByServingInfoParsed = servingInfo,
encoder = encoder,
parallelism = source.parallelism
)

}
.recover { case e: Exception =>
throw new IllegalArgumentException(s"Unable to lookup serving info for GroupBy: '$groupByName'", e)
}
.get
}
maybeServingInfo
.map { servingInfo =>
buildFlinkJob(groupByName, kafkaBootstrap, api, servingInfo)
}
.recover { case e: Exception =>
throw new IllegalArgumentException(s"Unable to lookup serving info for GroupBy: '$groupByName'", e)
}
.get

val env = StreamExecutionEnvironment.getExecutionEnvironment
env.enableCheckpointing(CheckPointInterval.toMillis, CheckpointingMode.AT_LEAST_ONCE)
@@ -397,13 +294,7 @@ object FlinkJob {

env.configure(config)

val jobDatastream = if (api.flagStore.isSet(FlagStoreConstants.TILING_ENABLED, Map.empty[String, String].toJava)) {
flinkJob
.runTiledGroupByJob(env)
} else {
flinkJob
.runGroupByJob(env)
}
val jobDatastream = flinkJob.runTiledGroupByJob(env)

jobDatastream
.addSink(new MetricsSink(flinkJob.groupByName))
Expand All @@ -414,6 +305,46 @@ object FlinkJob {
env.execute(s"${flinkJob.groupByName}")
}

private def buildFlinkJob(groupByName: String,
kafkaBootstrap: Option[String],
api: Api,
servingInfo: GroupByServingInfoParsed) = {
val topicUri = servingInfo.groupBy.streamingSource.get.topic
val topicInfo = TopicInfo.parse(topicUri)

val schemaProvider =
topicInfo.params.get(RegistryHostKey) match {
case Some(_) => new ProjectedSchemaRegistrySchemaProvider(topicInfo.params)
case None =>
throw new IllegalArgumentException(
s"We only support schema registry based schema lookups. Missing $RegistryHostKey in topic config")
}

val deserializationSchema = schemaProvider.buildDeserializationSchema(servingInfo.groupBy)
require(
deserializationSchema.isInstanceOf[SourceProjection],
s"Expect created deserialization schema for groupBy: $groupByName with $topicInfo to mixin SourceProjection. " +
s"We got: ${deserializationSchema.getClass.getSimpleName}"
)
val projectedSchema = deserializationSchema.asInstanceOf[SourceProjection].projectedSchema

val source =
topicInfo.messageBus match {
case "kafka" =>
new ProjectedKafkaFlinkSource(kafkaBootstrap, deserializationSchema, topicInfo)
case _ =>
throw new IllegalArgumentException(s"Unsupported message bus: ${topicInfo.messageBus}")
}

new FlinkJob(
eventSrc = source,
projectedSchema,
sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name),
groupByServingInfoParsed = servingInfo,
parallelism = source.parallelism
)
}
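The require call above leans on the SourceProjection mixin, whose definition isn't part of this diff; judging from its usage here, its contract is roughly the following (collection and import details assumed):

trait SourceProjection {
  // Field names of the projected output paired with their Chronon DataType;
  // buildFlinkJob feeds this straight into FlinkJob as inputSchema.
  def projectedSchema: Seq[(String, DataType)]
}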

private def buildApi(onlineClass: String, props: Map[String, String]): Api = {
val cl = Thread.currentThread().getContextClassLoader // Use Flink's classloader
val cls = cl.loadClass(onlineClass)
1 change: 0 additions & 1 deletion flink/src/main/scala/ai/chronon/flink/FlinkSource.scala
@@ -3,7 +3,6 @@ package ai.chronon.flink
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment

// TODO deprecate this in favor of Api.readTopic + Api.streamDecoder
abstract class FlinkSource[T] extends Serializable {

/** Return a Flink DataStream for the given topic and groupBy.
23 changes: 16 additions & 7 deletions flink/src/main/scala/ai/chronon/flink/KafkaFlinkSource.scala
@@ -11,11 +11,10 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
import org.apache.kafka.clients.consumer.OffsetResetStrategy
import org.apache.spark.sql.Row

class KafkaFlinkSource(kafkaBootstrap: Option[String],
deserializationSchema: DeserializationSchema[Row],
topicInfo: TopicInfo)
extends FlinkSource[Row] {

class BaseKafkaFlinkSource[T](kafkaBootstrap: Option[String],
deserializationSchema: DeserializationSchema[T],
topicInfo: TopicInfo)
extends FlinkSource[T] {
val bootstrap: String =
kafkaBootstrap.getOrElse(
topicInfo.params.getOrElse(
@@ -37,9 +36,9 @@
}

override def getDataStream(topic: String, groupByName: String)(env: StreamExecutionEnvironment,
parallelism: Int): SingleOutputStreamOperator[Row] = {
parallelism: Int): SingleOutputStreamOperator[T] = {
val kafkaSource = KafkaSource
.builder[Row]()
.builder[T]()
.setTopics(topicInfo.name)
.setGroupId(s"chronon-$groupByName")
// we might have a fairly large backlog to catch up on, so we choose to go with the latest offset when we're
@@ -55,3 +54,13 @@
.setParallelism(parallelism)
}
}

class KafkaFlinkSource(kafkaBootstrap: Option[String],
deserializationSchema: ChrononDeserializationSchema[Row],
topicInfo: TopicInfo)
extends BaseKafkaFlinkSource[Row](kafkaBootstrap, deserializationSchema, topicInfo)

class ProjectedKafkaFlinkSource(kafkaBootstrap: Option[String],
deserializationSchema: ChrononDeserializationSchema[Map[String, Any]],
topicInfo: TopicInfo)
extends BaseKafkaFlinkSource[Map[String, Any]](kafkaBootstrap, deserializationSchema, topicInfo)
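Since the Kafka source is now generic in its decoded payload type, supporting another payload shape is a one-line subclass. A hypothetical example (the event type and its deserialization schema are illustrative, not part of this change):

case class ClickEvent(userId: String, ts: Long)

class ClickEventKafkaFlinkSource(kafkaBootstrap: Option[String],
                                 deserializationSchema: ChrononDeserializationSchema[ClickEvent],
                                 topicInfo: TopicInfo)
    extends BaseKafkaFlinkSource[ClickEvent](kafkaBootstrap, deserializationSchema, topicInfo)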