diff --git a/build.sbt b/build.sbt
index d51b25b53e..706b7e99b1 100644
--- a/build.sbt
+++ b/build.sbt
@@ -107,7 +107,6 @@ val circe = Seq(
 ).map(_ % circeVersion)
 
 val flink_all = Seq(
-  "org.apache.flink" %% "flink-streaming-scala",
   "org.apache.flink" % "flink-metrics-dropwizard",
   "org.apache.flink" % "flink-clients",
   "org.apache.flink" % "flink-yarn"
@@ -214,6 +213,9 @@ lazy val flink = project
  .settings(
    libraryDependencies ++= spark_all,
    libraryDependencies ++= flink_all,
+    // mark flink-streaming-scala as provided, as otherwise we end up with extra Flink classes in our jar
+    // and runtime errors like: java.io.InvalidClassException: org.apache.flink.streaming.api.scala.DataStream$$anon$1; local class incompatible
+    libraryDependencies += "org.apache.flink" %% "flink-streaming-scala" % flink_1_17 % "provided",
    assembly / assemblyMergeStrategy := {
      case PathList("META-INF", "services", xs @ _*) => MergeStrategy.concat
      case "reference.conf"                          => MergeStrategy.concat
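Note that sbt-assembly excludes Provided dependencies from the fat jar, but sbt also drops them from the local `run` classpath. If the job ever needs to launch straight from sbt, the usual companion pattern (a sketch, not part of this change) is to re-add Provided deps to the run task in the flink module's settings:

    // Hypothetical settings addition: put Provided deps (the Flink runtime) back on
    // the classpath for `sbt flink/run`, while keeping them out of the assembly jar.
    Compile / run := Defaults
      .runTask(Compile / fullClasspath, Compile / run / mainClass, Compile / run / runner)
      .evaluated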
diff --git a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
index aa8f0c35f5..1e846f503f 100644
--- a/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
+++ b/cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
@@ -100,8 +100,35 @@ class DataprocSubmitter(jobControllerClient: JobControllerClient, conf: Submitte
   }
 
   private def buildFlinkJob(mainClass: String, mainJarUri: String, jarUri: String, args: String*): Job.Builder = {
+
+    // TODO: leverage a setting in teams.json when that's wired up
+    val checkpointsDir = "gs://zl-warehouse/flink-state"
+
+    // The JobManager is primarily responsible for coordinating the job (task slots, checkpoint triggering)
+    // and not much else, so 4G should suffice.
+    // We go with 64G TM containers (4 task slots per container).
+    // Broadly, Flink splits TM memory into:
+    // 1) Metaspace, framework off-heap, etc.
+    // 2) Network buffers
+    // 3) Managed memory (RocksDB)
+    // 4) JVM heap
+    // We tune the network buffers down to 1G-2G (the default would be ~6.3G) and put the reclaimed memory
+    // towards managed memory + JVM heap.
+    // Good doc: https://nightlies.apache.org/flink/flink-docs-master/docs/deployment/memory/mem_setup_tm
     val envProps =
-      Map("jobmanager.memory.process.size" -> "4G", "yarn.classpath.include-user-jar" -> "FIRST")
+      Map(
+        "jobmanager.memory.process.size" -> "4G",
+        "taskmanager.memory.process.size" -> "64G",
+        "taskmanager.memory.network.min" -> "1G",
+        "taskmanager.memory.network.max" -> "2G",
+        "taskmanager.memory.managed.fraction" -> "0.5",
+        "yarn.classpath.include-user-jar" -> "FIRST",
+        "state.savepoints.dir" -> checkpointsDir,
+        "state.checkpoints.dir" -> checkpointsDir,
+        // override the local dir for RocksDB, as the default path runs into file name length limits
+        "state.backend.rocksdb.localdir" -> "/tmp/flink-state",
+        "state.checkpoint-storage" -> "filesystem"
+      )
 
     val flinkJob = FlinkJob
       .newBuilder()
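For intuition on the numbers above, here is a back-of-the-envelope split of a 64G TM process under these settings. It assumes Flink's documented defaults for JVM overhead (fraction 0.1, capped at 1G) and metaspace (256m); treat it as an approximation, the memory doc linked in the comment is authoritative:

    // Rough memory split of a 64G TaskManager (values in GB, approximate).
    object TmMemorySketch extends App {
      val processSize = 64.0
      val jvmOverhead = math.min(processSize * 0.1, 1.0)           // default fraction 0.1, capped at 1 GB
      val metaspace = 0.25                                         // default jvm-metaspace of 256 MB
      val totalFlinkMemory = processSize - jvmOverhead - metaspace // ~62.75 GB
      val network = 2.0                                            // ~6.3 GB by default; capped here via network.max = 2G
      val managed = totalFlinkMemory * 0.5                         // managed.fraction = 0.5 -> ~31.4 GB for RocksDB
      val heapAndFramework = totalFlinkMemory - network - managed  // ~29.4 GB left for JVM heap, framework memory, etc.
      println(f"managed=$managed%.1f GB network=$network%.1f GB heap+framework=$heapAndFramework%.1f GB")
    }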
" + "Tiling is disabled.") + // we expect parallelism on the source stream to be set by the source provider val sourceStream: DataStream[T] = eventSrc .getDataStream(topic, groupByName)(env, parallelism) + .uid(s"source-$groupByName") + .name(s"Source for $groupByName") val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream .flatMap(exprEval) @@ -128,9 +138,12 @@ class FlinkJob[T](eventSrc: FlinkSource[T], val tilingWindowSizeInMillis: Option[Long] = ResolutionUtils.getSmallestWindowResolutionInMillis(groupByServingInfoParsed.groupBy) + // we expect parallelism on the source stream to be set by the source provider val sourceStream: DataStream[T] = eventSrc .getDataStream(topic, groupByName)(env, parallelism) + .uid(s"source-$groupByName") + .name(s"Source for $groupByName") val sparkExprEvalDS: DataStream[Map[String, Any]] = sourceStream .flatMap(exprEval) @@ -202,6 +215,26 @@ class FlinkJob[T](eventSrc: FlinkSource[T], } object FlinkJob { + // we set an explicit max parallelism to ensure if we do make parallelism setting updates, there's still room + // to restore the job from prior state. Number chosen does have perf ramifications if too high (can impact rocksdb perf) + // so we've chosen one that should allow us to scale to jobs in the 10K-50K events / s range. + val MaxParallelism = 1260 // highly composite number + + // We choose to checkpoint frequently to ensure the incremental checkpoints are small in size + // as well as ensuring the catch-up backlog is fairly small in case of failures + val CheckPointInterval: FiniteDuration = 10.seconds + + // We set a more lenient checkpoint timeout to guard against large backlog / catchup scenarios where checkpoints + // might be slow and a tight timeout will set us on a snowball restart loop + val CheckpointTimeout: FiniteDuration = 5.minutes + + // We use incremental checkpoints and we cap how many we keep around + val MaxRetainedCheckpoints = 10 + + // how many consecutive checkpoint failures can we tolerate - default is 0, we choose a more lenient value + // to allow us a few tries before we give up + val TolerableCheckpointFailures = 5 + // Pull in the Serialization trait to sidestep: https://github.com/scallop/scallop/issues/137 class JobArgs(args: Seq[String]) extends ScallopConf(args) with Serialization { val onlineClass: ScallopOption[String] = @@ -235,13 +268,39 @@ object FlinkJob { // based on the topic type (e.g. kafka / pubsub) and the schema class name: // 1. lookup schema object using SchemaProvider (e.g SchemaRegistry / Jar based) // 2. Create the appropriate Encoder for the given schema type - // 3. Invoke the appropriate source provider to get the source, encoder, parallelism + // 3. 
diff --git a/flink/src/main/scala/ai/chronon/flink/MetricsSink.scala b/flink/src/main/scala/ai/chronon/flink/MetricsSink.scala
new file mode 100644
index 0000000000..016f89a2bd
--- /dev/null
+++ b/flink/src/main/scala/ai/chronon/flink/MetricsSink.scala
@@ -0,0 +1,38 @@
+package ai.chronon.flink
+
+import com.codahale.metrics.ExponentiallyDecayingReservoir
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.dropwizard.metrics.DropwizardHistogramWrapper
+import org.apache.flink.metrics.Histogram
+import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
+import org.apache.flink.streaming.api.functions.sink.SinkFunction
+
+/**
+ * Sink that captures metrics around feature freshness. We capture the time taken from event creation to the KV store
+ * sink. Ideally we expect this to match the Kafka persistence -> sink time. They can diverge if the event object is
+ * created and held in the source service for some time before the event is submitted to Kafka.
+ */
+class MetricsSink(groupByName: String) extends RichSinkFunction[WriteResponse] {
+
+  @transient private var eventCreatedToSinkTimeHistogram: Histogram = _
+
+  override def open(parameters: Configuration): Unit = {
+    super.open(parameters)
+    val metricsGroup = getRuntimeContext.getMetricGroup
+      .addGroup("chronon")
+      .addGroup("feature_group", groupByName)
+
+    eventCreatedToSinkTimeHistogram = metricsGroup.histogram(
+      "event_created_to_sink_time",
+      new DropwizardHistogramWrapper(
+        new com.codahale.metrics.Histogram(new ExponentiallyDecayingReservoir())
+      )
+    )
+  }
+
+  override def invoke(value: WriteResponse, context: SinkFunction.Context): Unit = {
+    // tsMillis is the event creation time; skip the metric if it's absent rather than throwing on .get
+    value.putRequest.tsMillis.foreach { eventCreatedTs =>
+      eventCreatedToSinkTimeHistogram.update(System.currentTimeMillis() - eventCreatedTs)
+    }
+  }
+}
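The histogram registers under the operator's metric group, so it surfaces through whatever metrics reporter the cluster runs with, scoped roughly as ...<operator>.chronon.feature_group.<groupByName>.event_created_to_sink_time. For local debugging, one option (a sketch; these are standard Flink reporter config keys, but they are not added by this change) is to enable the bundled JMX reporter alongside the Dataproc envProps:

    // Hypothetical extra envProps entries to expose Flink metrics over JMX while debugging.
    val jmxMetricsProps = Map(
      "metrics.reporter.jmx.factory.class" -> "org.apache.flink.metrics.jmx.JMXReporterFactory",
      "metrics.reporter.jmx.port" -> "8789"
    )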
diff --git a/flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala b/flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala
index 83d1cde90f..4e05413899 100644
--- a/flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala
+++ b/flink/src/main/scala/ai/chronon/flink/TestFlinkJob.scala
@@ -15,6 +15,7 @@ import ai.chronon.online.Extensions.StructTypeOps
 import ai.chronon.online.GroupByServingInfoParsed
 import org.apache.flink.api.scala.createTypeInformation
 import org.apache.flink.streaming.api.functions.sink.SinkFunction
+import org.apache.flink.streaming.api.functions.source.SourceFunction
 import org.apache.flink.streaming.api.scala.DataStream
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.spark.sql.Encoder
@@ -37,7 +38,25 @@ class E2EEventSource(mockEvents: Seq[E2ETestEvent]) extends FlinkSource[E2ETestE
 
   override def getDataStream(topic: String, groupName: String)(env: StreamExecutionEnvironment,
                                                                parallelism: Int): DataStream[E2ETestEvent] = {
-    env.fromCollection(mockEvents)
+    env
+      .addSource(new SourceFunction[E2ETestEvent] {
+        @volatile private var isRunning = true // cancel() may be called from a different thread than run()
+
+        override def run(ctx: SourceFunction.SourceContext[E2ETestEvent]): Unit = {
+          while (isRunning) {
+            mockEvents.foreach { event =>
+              ctx.collect(event)
+            }
+            // Add some delay between event batches
+            Thread.sleep(1000)
+          }
+        }
+
+        override def cancel(): Unit = {
+          isRunning = false
+        }
+      })
+      .setParallelism(1)
   }
 }
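For a quick local check of this looping source, a sketch (mockEvents stands in for whatever Seq[E2ETestEvent] fixture TestFlinkJob already builds; the job runs until cancelled, since the source never exhausts):

    import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment

    // Run the mock source on a local mini-cluster and print what it emits.
    def smokeRun(mockEvents: Seq[E2ETestEvent]): Unit = {
      val env = StreamExecutionEnvironment.createLocalEnvironment(1)
      new E2EEventSource(mockEvents)
        .getDataStream("test-topic", "test-group")(env, parallelism = 1)
        .print()
      env.execute("E2EEventSource smoke run")
    }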