@@ -14,7 +14,7 @@
* limitations under the License.
*/

package ai.chronon.spark.stats
package ai.chronon.aggregator.stats

object EditDistance {

@@ -14,9 +14,9 @@
* limitations under the License.
*/

package ai.chronon.spark.test
package ai.chronon.aggregator.test

import ai.chronon.spark.stats.EditDistance
import ai.chronon.aggregator.stats.EditDistance
import org.junit.Assert.assertEquals
import org.scalatest.flatspec.AnyFlatSpec

3 changes: 3 additions & 0 deletions build.sbt
@@ -217,12 +217,14 @@ lazy val spark = project
lazy val flink = project
.dependsOn(aggregator.%("compile->compile;test->test"), online.%("compile->compile;test->test"))
.settings(
resolvers += "Confluent" at "https://packages.confluent.io/maven/", // needed for confluent's schema registry
libraryDependencies ++= spark_all,
libraryDependencies ++= flink_all,
// mark the flink-streaming scala as provided as otherwise we end up with some extra Flink classes in our jar
// and errors at runtime like: java.io.InvalidClassException: org.apache.flink.streaming.api.scala.DataStream$$anon$1; local class incompatible
libraryDependencies += "org.apache.flink" %% "flink-streaming-scala" % flink_1_17 % "provided",
libraryDependencies += "org.apache.flink" % "flink-connector-files" % flink_1_17 % "provided",
libraryDependencies += "io.confluent" % "kafka-schema-registry-client" % "7.6.0",
libraryDependencies += "org.apache.spark" %% "spark-avro" % spark_3_5,
assembly / assemblyMergeStrategy := {
case PathList("META-INF", "services", xs @ _*) => MergeStrategy.concat
@@ -246,6 +248,7 @@ lazy val flink = project
assembly / packageOptions += Package.ManifestAttributes(
("Main-Class", "ai.chronon.flink.FlinkJob")
),
libraryDependencies += "io.confluent" % "kafka-protobuf-provider" % "7.6.0" % Test,
libraryDependencies += "org.apache.flink" % "flink-test-utils" % flink_1_17 % Test excludeAll (
ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-api"),
ExclusionRule(organization = "org.apache.logging.log4j", name = "log4j-core"),
@@ -64,9 +64,10 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
JarURI -> "gs://zipline-jars/cloud_gcp_bigtable.jar"),
List.empty,
"--online-class=ai.chronon.integrations.cloud_gcp.GcpApiImpl",
"--groupby-name=e2e-count",
"-ZGCP_PROJECT_ID=bigtable-project-id",
"-ZGCP_BIGTABLE_INSTANCE_ID=bigtable-instance-id")
"--groupby-name=etsy.listing_canary.actions_v1",
"-kafka-bootstrap=bootstrap.zipline-kafka-cluster.us-central1.managedkafka.canary-443022.cloud.goog:9092",
"-ZGCP_PROJECT_ID=canary-443022",
"-ZGCP_BIGTABLE_INSTANCE_ID=zipline-canary-instance")
}

it should "test flink kafka ingest job locally" ignore {
106 changes: 92 additions & 14 deletions flink/src/main/scala/ai/chronon/flink/FlinkJob.scala
@@ -1,9 +1,12 @@
package ai.chronon.flink

import ai.chronon.aggregator.windowing.ResolutionUtils
import ai.chronon.api.Constants
import ai.chronon.api.Constants.MetadataDataset
import ai.chronon.api.DataType
import ai.chronon.api.Extensions.GroupByOps
import ai.chronon.api.Extensions.SourceOps
import ai.chronon.flink.FlinkJob.watermarkStrategy
import ai.chronon.flink.window.AlwaysFireOnElementTrigger
import ai.chronon.flink.window.FlinkRowAggProcessFunction
import ai.chronon.flink.window.FlinkRowAggregationFunction
@@ -12,7 +15,11 @@ import ai.chronon.flink.window.TimestampedTile
import ai.chronon.online.Api
import ai.chronon.online.GroupByServingInfoParsed
import ai.chronon.online.KVStore.PutRequest
import ai.chronon.online.MetadataStore
import ai.chronon.online.SparkConversions
import ai.chronon.online.TopicInfo
import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner
import org.apache.flink.api.common.eventtime.WatermarkStrategy
import org.apache.flink.api.scala._
import org.apache.flink.configuration.CheckpointingOptions
import org.apache.flink.configuration.Configuration
@@ -33,6 +40,7 @@ import org.rogach.scallop.ScallopOption
import org.rogach.scallop.Serialization
import org.slf4j.LoggerFactory

import java.time.Duration
import scala.concurrent.duration.DurationInt
import scala.concurrent.duration.FiniteDuration

@@ -103,7 +111,13 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
.name(s"Spark expression eval for $groupByName")
.setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator

val putRecordDS: DataStream[PutRequest] = sparkExprEvalDS
val sparkExprEvalDSWithWatermarks: DataStream[Map[String, Any]] = sparkExprEvalDS
.assignTimestampsAndWatermarks(watermarkStrategy)
.uid(s"spark-expr-eval-timestamps-$groupByName")
.name(s"Spark expression eval with timestamps for $groupByName")
.setParallelism(sourceStream.parallelism)

val putRecordDS: DataStream[PutRequest] = sparkExprEvalDSWithWatermarks
.flatMap(AvroCodecFn[T](groupByServingInfoParsed))
.uid(s"avro-conversion-$groupByName")
.name(s"Avro conversion for $groupByName")
@@ -152,6 +166,12 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
.name(s"Spark expression eval for $groupByName")
.setParallelism(sourceStream.parallelism) // Use same parallelism as previous operator

val sparkExprEvalDSAndWatermarks: DataStream[Map[String, Any]] = sparkExprEvalDS
.assignTimestampsAndWatermarks(watermarkStrategy)
.uid(s"spark-expr-eval-timestamps-$groupByName")
.name(s"Spark expression eval with timestamps for $groupByName")
.setParallelism(sourceStream.parallelism)

val inputSchema: Seq[(String, DataType)] =
exprEval.getOutputSchema.fields
.map(field => (field.name, SparkConversions.toChrononType(field.name, field.dataType)))
@@ -179,7 +199,7 @@ class FlinkJob[T](eventSrc: FlinkSource[T],
// - The only purpose of this window function is to mark tiles as closed so we can do client-side caching in SFS
// 7. Output: TimestampedTile, containing the current IRs (Avro encoded) and the timestamp of the current element
val tilingDS: DataStream[TimestampedTile] =
sparkExprEvalDS
sparkExprEvalDSAndWatermarks
.keyBy(KeySelector.getKeySelectionFunction(groupByServingInfoParsed.groupBy))
.window(window)
.trigger(trigger)
@@ -236,6 +256,25 @@ object FlinkJob {
// to allow us a few tries before we give up
val TolerableCheckpointFailures = 5

// Keep windows open for a bit longer before closing to ensure we don't lose data due to late arrivals (needed in case of
// tiling implementation)
val AllowedOutOfOrderness: Duration = Duration.ofMinutes(5)

// Set an idleness timeout to keep time moving in case of very low traffic event streams as well as late events during
// large backlog catchups
val IdlenessTimeout: Duration = Duration.ofSeconds(30)

// We wire up the watermark strategy post the spark expr eval to be able to leverage the user's timestamp column (which is
// ETLed to Constants.TimeColumn) as the event timestamp and watermark
val watermarkStrategy: WatermarkStrategy[Map[String, Any]] = WatermarkStrategy
.forBoundedOutOfOrderness[Map[String, Any]](AllowedOutOfOrderness)
.withIdleness(IdlenessTimeout)
.withTimestampAssigner(new SerializableTimestampAssigner[Map[String, Any]] {
override def extractTimestamp(element: Map[String, Any], recordTimestamp: Long): Long = {
element.get(Constants.TimeColumn).map(_.asInstanceOf[Long]).getOrElse(recordTimestamp)
}
})

Review comment, on lines +270 to +277 (the watermark strategy):
very cool!

Review comment (nikhil-zlai, Contributor, Jan 28, 2025), on the watermark strategy:
Wonder if we should do Array[Any] & a separate schema instead of Map[String, Any]. We can do that when we need to optimize later, since it is backwards compatible, I think.

Reply (author):
We plug the watermark strategy downstream of the Spark expression eval so that we can pick up the timestamp as the user defined it in their GroupBy source. The Spark eval output is Map[String, Any], which is the type used here (we just reach in and pull out the 'ts' column). One thing to call out: the Map[String, Any] type only lives for one operator hop. Immediately after the Spark eval + watermark op, an Avro codec operator converts it to the appropriate event Avro ready for writing to the KV store.
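To make the reviewer's alternative concrete, here is a minimal, hypothetical sketch (not part of this PR) of what an Array[Any]-plus-separate-schema timestamp assigner could look like; ArrayTimestampAssigner and its fieldNames parameter are illustrative assumptions:

```scala
import ai.chronon.api.Constants
import org.apache.flink.api.common.eventtime.SerializableTimestampAssigner

// Resolve the position of the time column once from a separately carried schema (field names),
// then read timestamps by index instead of doing a per-record Map lookup.
class ArrayTimestampAssigner(fieldNames: Seq[String]) extends SerializableTimestampAssigner[Array[Any]] {
  private val tsIndex: Int = fieldNames.indexOf(Constants.TimeColumn)

  override def extractTimestamp(element: Array[Any], recordTimestamp: Long): Long =
    if (tsIndex >= 0 && element(tsIndex) != null) element(tsIndex).asInstanceOf[Long]
    else recordTimestamp // fall back to the record's own timestamp if 'ts' is absent
}
```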

// Pull in the Serialization trait to sidestep: https://github.com/scallop/scallop/issues/137
class JobArgs(args: Seq[String]) extends ScallopConf(args) with Serialization {
val onlineClass: ScallopOption[String] =
@@ -244,7 +283,10 @@ object FlinkJob {
val groupbyName: ScallopOption[String] =
opt[String](required = true, descr = "The name of the groupBy to process")
val mockSource: ScallopOption[Boolean] =
opt[Boolean](required = false, descr = "Use a mocked data source instead of a real source", default = Some(true))
opt[Boolean](required = false, descr = "Use a mocked data source instead of a real source", default = Some(false))
// Kafka config is optional as we can support other sources in the future
val kafkaBootstrap: ScallopOption[String] =
opt[String](required = false, descr = "Kafka bootstrap server in host:port format")

val apiProps: Map[String, String] = props[String]('Z', descr = "Props to configure API / KV Store")

@@ -253,28 +295,60 @@ object FlinkJob {

def main(args: Array[String]): Unit = {
val jobArgs = new JobArgs(args)
jobArgs.groupbyName()
val groupByName = jobArgs.groupbyName()
val onlineClassName = jobArgs.onlineClass()
val props = jobArgs.apiProps.map(identity)
val useMockedSource = jobArgs.mockSource()
val kafkaBootstrap = jobArgs.kafkaBootstrap.toOption

val api = buildApi(onlineClassName, props)
val metadataStore = new MetadataStore(api.genKvStore, MetadataDataset, timeoutMillis = 10000)

val flinkJob =
if (useMockedSource) {
// We will yank this conditional block when we wire up our real sources etc.
TestFlinkJob.buildTestFlinkJob(api)
} else {
// TODO - what we need to do when we wire this up for real
// lookup groupByServingInfo by groupByName from the kv store
// based on the topic type (e.g. kafka / pubsub) and the schema class name:
// 1. lookup schema object using SchemaProvider (e.g SchemaRegistry / Jar based)
// 2. Create the appropriate Encoder for the given schema type
// 3. Invoke the appropriate source provider to get the source, parallelism
throw new IllegalArgumentException("We don't support non-mocked sources like Kafka / PubSub yet!")
val maybeServingInfo = metadataStore.getGroupByServingInfo(groupByName)
maybeServingInfo
.map { servingInfo =>
val topicUri = servingInfo.groupBy.streamingSource.get.topic
val topicInfo = TopicInfo.parse(topicUri)

val schemaProvider =
topicInfo.params.get("registry_url") match {
case Some(_) => new SchemaRegistrySchemaProvider(topicInfo.params)
case None =>
throw new IllegalArgumentException(
"We only support schema registry based schema lookups. Missing registry_url in topic config")
}

val (encoder, deserializationSchema) = schemaProvider.buildEncoderAndDeserSchema(topicInfo)
val source =
topicInfo.messageBus match {
case "kafka" =>
new KafkaFlinkSource(kafkaBootstrap, deserializationSchema, topicInfo)
case _ =>
throw new IllegalArgumentException(s"Unsupported message bus: ${topicInfo.messageBus}")
}

new FlinkJob(
eventSrc = source,
sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name),
groupByServingInfoParsed = servingInfo,
encoder = encoder,
parallelism = source.parallelism
)

}
.recover {
case e: Exception =>
throw new IllegalArgumentException(s"Unable to lookup serving info for GroupBy: '$groupByName'", e)
}
.get
}

val env = StreamExecutionEnvironment.getExecutionEnvironment

env.enableCheckpointing(CheckPointInterval.toMillis, CheckpointingMode.AT_LEAST_ONCE)
val checkpointConfig = env.getCheckpointConfig
checkpointConfig.setMinPauseBetweenCheckpoints(CheckPointInterval.toMillis)
@@ -300,15 +374,19 @@ object FlinkJob {

env.configure(config)

flinkJob
val jobDatastream = flinkJob
.runGroupByJob(env)

jobDatastream
.addSink(new MetricsSink(flinkJob.groupByName))
.uid(s"metrics-sink - ${flinkJob.groupByName}")
.name(s"Metrics Sink for ${flinkJob.groupByName}")
.setParallelism(jobDatastream.parallelism)

env.execute(s"${flinkJob.groupByName}")
}

def buildApi(onlineClass: String, props: Map[String, String]): Api = {
private def buildApi(onlineClass: String, props: Map[String, String]): Api = {
val cl = Thread.currentThread().getContextClassLoader // Use Flink's classloader
val cls = cl.loadClass(onlineClass)
val constructor = cls.getConstructors.apply(0)
53 changes: 53 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/KafkaFlinkSource.scala
@@ -0,0 +1,53 @@
package ai.chronon.flink

import ai.chronon.online.TopicChecker
import ai.chronon.online.TopicInfo
import org.apache.flink.api.common.eventtime.WatermarkStrategy
import org.apache.flink.api.common.serialization.DeserializationSchema
import org.apache.flink.api.scala.createTypeInformation
import org.apache.flink.connector.kafka.source.KafkaSource
import org.apache.flink.connector.kafka.source.enumerator.initializer.OffsetsInitializer
import org.apache.flink.streaming.api.scala.DataStream
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.kafka.clients.consumer.OffsetResetStrategy
import org.apache.spark.sql.Row

class KafkaFlinkSource(kafkaBootstrap: Option[String],
deserializationSchema: DeserializationSchema[Row],
topicInfo: TopicInfo)
extends FlinkSource[Row] {

val bootstrap: String =
kafkaBootstrap.getOrElse(
topicInfo.params.getOrElse(
"bootstrap",
topicInfo.params("host") + topicInfo.params
.get("port")
.map(":" + _)
.getOrElse(throw new IllegalArgumentException("No bootstrap servers provided"))
))

// confirm the topic exists
TopicChecker.topicShouldExist(topicInfo.name, bootstrap, topicInfo.params)

Review comment, on lines +30 to +32 (potential issue):
Close AdminClient to avoid leaks. Topic existence is checked, but no final close is called on AdminClient.

+ try {
+   TopicChecker.topicShouldExist(topicInfo.name, bootstrap, topicInfo.params)
+ } finally {
+   // ensure resources are closed
+ }

Committable suggestion skipped: line range outside the PR's diff.

implicit val parallelism: Int = TopicChecker.getPartitions(topicInfo.name, bootstrap, topicInfo.params)

override def getDataStream(topic: String, groupByName: String)(env: StreamExecutionEnvironment,
parallelism: Int): DataStream[Row] = {
val kafkaSource = KafkaSource
.builder[Row]()
.setTopics(topicInfo.name)
.setGroupId(s"chronon-$groupByName")
// we might have a fairly large backlog to catch up on, so we choose to go with the latest offset when we're
// starting afresh
.setStartingOffsets(OffsetsInitializer.committedOffsets(OffsetResetStrategy.LATEST))
.setValueOnlyDeserializer(deserializationSchema)
.setBootstrapServers(bootstrap)
.setProperties(TopicChecker.mapToJavaProperties(topicInfo.params))
.build()

env
.fromSource(kafkaSource, WatermarkStrategy.noWatermarks(), s"Kafka source: $groupByName - ${topicInfo.name}")
.setParallelism(parallelism)
}
}
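Following up on the AdminClient review note above: a minimal sketch of the close-in-finally pattern the reviewer suggests, assuming the existence check builds its own Kafka AdminClient (the helper below is hypothetical, not the actual TopicChecker code):

```scala
import java.util.Properties
import org.apache.kafka.clients.admin.{AdminClient, AdminClientConfig}

object TopicExistenceCheck {
  // Hypothetical helper: create an AdminClient, check topic existence, and always close the client.
  def topicExists(topic: String, bootstrap: String): Boolean = {
    val props = new Properties()
    props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrap)
    val adminClient = AdminClient.create(props)
    try {
      adminClient.listTopics().names().get().contains(topic)
    } finally {
      adminClient.close() // release sockets and threads even if the lookup throws
    }
  }
}
```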
17 changes: 17 additions & 0 deletions flink/src/main/scala/ai/chronon/flink/SchemaProvider.scala
@@ -0,0 +1,17 @@
package ai.chronon.flink

import ai.chronon.online.TopicInfo
import org.apache.flink.api.common.serialization.DeserializationSchema
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Row

/**
* A SchemaProvider is responsible for providing the Encoder and DeserializationSchema for a given topic.
* This class handles looking up the schema and then based on the schema type (e.g. Avro, Protobuf) it will create
* the appropriate Encoder and DeserializationSchema. The Encoder is needed for SparkExpressionEval and the DeserializationSchema
* is needed to allow Flink's Kafka / other sources to crack open the Array[Byte] payloads.
* @param conf - Configuration for the SchemaProvider (we pick this up from the topicInfo param map)
*/
abstract class SchemaProvider(conf: Map[String, String]) {
def buildEncoderAndDeserSchema(topicInfo: TopicInfo): (Encoder[Row], DeserializationSchema[Row])
}
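As a usage illustration, here is a sketch mirroring the wiring in FlinkJob.main from this PR; the object, method, and parameter names below are assumptions for illustration only:

```scala
import ai.chronon.online.{Api, GroupByServingInfoParsed, TopicInfo}
import org.apache.spark.sql.Row

object SchemaProviderWiring {
  // Sketch: the DeserializationSchema feeds the Kafka source so it can crack open Array[Byte]
  // payloads into Rows, while the Encoder feeds Spark expression evaluation inside FlinkJob.
  def buildJob(api: Api,
               servingInfo: GroupByServingInfoParsed,
               topicInfo: TopicInfo,
               kafkaBootstrap: Option[String]): FlinkJob[Row] = {
    val provider: SchemaProvider = new SchemaRegistrySchemaProvider(topicInfo.params)
    val (encoder, deserializationSchema) = provider.buildEncoderAndDeserSchema(topicInfo)
    val source = new KafkaFlinkSource(kafkaBootstrap, deserializationSchema, topicInfo)
    new FlinkJob(
      eventSrc = source,
      sinkFn = new AsyncKVStoreWriter(api, servingInfo.groupBy.metaData.name),
      groupByServingInfoParsed = servingInfo,
      encoder = encoder,
      parallelism = source.parallelism
    )
  }
}
```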
@@ -0,0 +1,45 @@
package ai.chronon.flink
import ai.chronon.online.TopicInfo
import io.confluent.kafka.schemaregistry.avro.AvroSchema
import io.confluent.kafka.schemaregistry.client.CachedSchemaRegistryClient
import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient
import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException
import org.apache.flink.api.common.serialization.DeserializationSchema
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Row
import org.apache.spark.sql.avro.AvroDeserializationSupport

class SchemaRegistrySchemaProvider(conf: Map[String, String]) extends SchemaProvider(conf) {

private val schemaRegistryUrl: String =
conf.getOrElse("registry_url", throw new IllegalArgumentException("registry_url not set"))
private val CacheCapacity: Int = 10

private val schemaRegistryClient: SchemaRegistryClient = buildSchemaRegistryClient(schemaRegistryUrl)

private[flink] def buildSchemaRegistryClient(registryUrl: String): SchemaRegistryClient =
new CachedSchemaRegistryClient(registryUrl, CacheCapacity)
Review comment, on lines +18 to +21 (refactor suggestion):
Add error handling for client initialization. Catch and wrap initialization errors; consider adding a retry mechanism.

-  private val schemaRegistryClient: SchemaRegistryClient = buildSchemaRegistryClient(schemaRegistryUrl)
+  private val schemaRegistryClient: SchemaRegistryClient = try {
+    buildSchemaRegistryClient(schemaRegistryUrl)
+  } catch {
+    case e: Exception => throw new IllegalStateException(s"Failed to initialize schema registry client: ${e.getMessage}", e)
+  }

override def buildEncoderAndDeserSchema(topicInfo: TopicInfo): (Encoder[Row], DeserializationSchema[Row]) = {
val subject = topicInfo.params.getOrElse("subject", s"${topicInfo.name}-value")
val parsedSchema =
try {
val metadata = schemaRegistryClient.getLatestSchemaMetadata(subject)
schemaRegistryClient.getSchemaById(metadata.getId)
} catch {
case e: RestClientException =>
throw new IllegalArgumentException(
s"Failed to retrieve schema details from the registry. Status: ${e.getStatus}; Error code: ${e.getErrorCode}",
e)
case e: Exception =>
throw new IllegalArgumentException("Error connecting to and requesting schema details from the registry", e)
}
// we currently only support Avro encoders
parsedSchema.schemaType() match {
case AvroSchema.TYPE =>
val schema = parsedSchema.asInstanceOf[AvroSchema]
AvroDeserializationSupport.build(topicInfo.name, schema.canonicalString(), schemaRegistryWireFormat = true)
case _ => throw new IllegalArgumentException(s"Unsupported schema type: ${parsedSchema.schemaType()}")
}
}
}
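Relatedly, a minimal sketch of the retry idea raised in the review comment above (a hypothetical helper, not part of this PR):

```scala
import io.confluent.kafka.schemaregistry.client.{CachedSchemaRegistryClient, SchemaRegistryClient}

import scala.util.{Failure, Success, Try}

object SchemaRegistryClientBuilder {
  // Sketch: retry client construction a few times with a simple linear backoff before giving up.
  def buildWithRetry(registryUrl: String, cacheCapacity: Int = 10, attempts: Int = 3): SchemaRegistryClient = {
    def attemptBuild(attempt: Int): SchemaRegistryClient =
      Try(new CachedSchemaRegistryClient(registryUrl, cacheCapacity)) match {
        case Success(client) => client
        case Failure(_) if attempt < attempts =>
          Thread.sleep(1000L * attempt) // back off a little longer after each failed attempt
          attemptBuild(attempt + 1)
        case Failure(e) =>
          throw new IllegalStateException(s"Failed to initialize schema registry client after $attempts attempts", e)
      }
    attemptBuild(1)
  }
}
```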
23 changes: 0 additions & 23 deletions flink/src/main/scala/ai/chronon/flink/SourceProvider.scala

This file was deleted.
