Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/python/ai/chronon/repo/default_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, args, jar_path):
self.custom_savepoint = args.get("custom_savepoint")
self.no_savepoint = args.get("no_savepoint")
self.version_check = args.get("version_check")
self.additional_jars = args.get("additional_jars")

flink_state_uri = args.get("flink_state_uri")
if flink_state_uri:
Expand Down
6 changes: 5 additions & 1 deletion api/python/ai/chronon/repo/gcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def run_dataproc_flink_streaming(self):
elif "deploy" in args:
user_args["--streaming-mode"] = "deploy"

flag_args = {"--mock-source": self.mock_source, "--validate": self.validate}
flag_args = {"--validate": self.validate}

# Set the savepoint deploy strategy
if self.latest_savepoint:
Expand All @@ -356,6 +356,10 @@ def run_dataproc_flink_streaming(self):
if self.version_check:
flag_args["--version-check"] = self.version_check

# Set additional jars
if self.additional_jars:
user_args["--additional-jars"] = self.additional_jars

user_args_str = " ".join(f"{key}={value}" for key, value in user_args.items() if value)
flag_args_str = " ".join(key for key, value in flag_args.items() if value)
dataproc_args = self.generate_dataproc_submitter_args(
Expand Down
22 changes: 15 additions & 7 deletions api/python/ai/chronon/repo/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ def validate_flink_state(ctx, param, value):
)
return value

def validate_additional_jars(ctx, param, value):
    """Click callback validating the --additional-jars option.

    The option is a comma-separated list of remote jar URIs to append to the
    Flink job classpath. Every entry must live in object storage (gs:// or
    s3://) since only remote jars can be staged by the submitter.

    Args:
        ctx: Click context (unused, required by the callback signature).
        param: Click parameter (unused, required by the callback signature).
        value: Raw option string, or None/empty when the option was not given.

    Returns:
        None/empty value unchanged, otherwise the normalized (whitespace-
        stripped) comma-separated list.

    Raises:
        click.BadParameter: if any entry is empty or does not start with a
            supported object-store scheme.
    """
    if not value:
        # Option absent (None) or explicitly empty — nothing to validate.
        return value
    # Strip incidental whitespace so "gs://a.jar, gs://b.jar" is accepted.
    jars = [jar.strip() for jar in value.split(',')]
    for jar in jars:
        # Empty segments (e.g. a trailing comma) are rejected alongside
        # non-object-store URIs.
        if not jar.startswith(('gs://', 's3://')):
            raise click.BadParameter(
                f"Additional jars must start with gs://, s3://: {jar}"
            )
    # Return the normalized list so downstream consumers (which split on
    # ',' again) never see stray whitespace.
    return ','.join(jars)

@click.command(
name="run",
context_settings=dict(allow_extra_args=True, ignore_unknown_options=True),
Expand Down Expand Up @@ -173,11 +183,9 @@ def validate_flink_state(ctx, param, value):
@click.option("--flink-state-uri",
help="Bucket for storing flink state checkpoints/savepoints and other internal pieces for orchestration.",
callback=validate_flink_state)
@click.option(
"--mock-source",
is_flag=True,
help="Use a mocked data source instead of a real source for groupby-streaming Flink.",
)
@click.option("--additional-jars",
help="Comma separated list of additional jar URIs to be included in the Flink job classpath (e.g. gs://bucket/jar1.jar,gs://bucket/jar2.jar).",
callback=validate_additional_jars)
@click.option(
"--validate",
is_flag=True,
Expand Down Expand Up @@ -224,12 +232,12 @@ def main(
no_savepoint,
version_check,
flink_state_uri,
mock_source,
validate,
validate_rows,
join_part_name,
artifact_prefix,
disable_cloud_logging
disable_cloud_logging,
additional_jars,
):
unknown_args = ctx.args
click.echo("Running with args: {}".format(ctx.params))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,14 @@ class DataprocSubmitter(jobControllerClient: JobControllerClient,
throw new RuntimeException(s"Missing expected $FlinkCheckpointUri"))
val maybeSavepointUri = submissionProperties.get(SavepointUri)
val maybePubSubConnectorJarUri = submissionProperties.get(FlinkPubSubConnectorJarURI)

val maybeAdditionalJarsUri = submissionProperties.get(AdditionalJars)
val additionalJars = maybeAdditionalJarsUri.map(_.split(",")).getOrElse(Array.empty)
val jarUris = Array(jarUri) ++ maybePubSubConnectorJarUri.toList ++ additionalJars
buildFlinkJob(mainClass,
mainJarUri,
jarUri,
jarUris,
flinkCheckpointPath,
maybeSavepointUri,
maybePubSubConnectorJarUri,
jobProperties,
(args :+ "--parent-job-id" :+ jobId): _*)
}
Expand Down Expand Up @@ -238,10 +239,9 @@ class DataprocSubmitter(jobControllerClient: JobControllerClient,

private[cloud_gcp] def buildFlinkJob(mainClass: String,
mainJarUri: String,
jarUri: String,
jarUris: Array[String],
flinkCheckpointUri: String,
maybeSavePointUri: Option[String],
maybePubSubConnectorJarUri: Option[String],
jobProperties: Map[String, String],
args: String*): Job.Builder = {

Expand Down Expand Up @@ -290,7 +290,6 @@ class DataprocSubmitter(jobControllerClient: JobControllerClient,
"state.checkpoints.num-retained" -> MaxRetainedCheckpoints
)

val jarUris = Array(jarUri) ++ maybePubSubConnectorJarUri.toList
val flinkJobBuilder = FlinkJob
.newBuilder()
.setMainClass(mainClass)
Expand Down Expand Up @@ -402,6 +401,8 @@ object DataprocSubmitter {
// pull the pubsub connector uri if it has been passed
val maybePubSubJarUri = JobSubmitter
.getArgValue(args, FlinkPubSubJarUriArgKeyword)
// include additional jars if present
val additionalJars = JobSubmitter.getArgValue(args, AdditionalJarsUriArgKeyword)

val baseJobProps = Map(
MainClass -> mainClass,
Expand All @@ -410,7 +411,7 @@ object DataprocSubmitter {
FlinkCheckpointUri -> flinkCheckpointUri,
MetadataName -> metadataName,
JobId -> jobId
) ++ maybePubSubJarUri.map(FlinkPubSubConnectorJarURI -> _)
) ++ (maybePubSubJarUri.map(FlinkPubSubConnectorJarURI -> _) ++ additionalJars.map(AdditionalJars -> _))

val groupByName = JobSubmitter
.getArgValue(args, GroupByNameArgKeyword)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
conf = SubmitterConf("test-project", "test-region", "test-cluster"))
val job = submitter.buildFlinkJob(
mainClass = "ai.chronon.flink.FlinkJob",
jarUri = "gs://zipline-jars/cloud-gcp.jar",
jarUris = Array("gs://zipline-jars/cloud-gcp.jar"),
mainJarUri = "gs://zipline-jars/flink-assembly-0.1.0-SNAPSHOT.jar",
flinkCheckpointUri = "gs://zl-warehouse/flink-state",
maybeSavePointUri = Option("gs://zipline-warehouse/flink-state/groupby-name/chk-1"),
maybePubSubConnectorJarUri = None,
jobProperties = Map("key" -> "value"),
args = List("args1", "args2"): _*
)
Expand Down Expand Up @@ -97,11 +96,10 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
conf = SubmitterConf("test-project", "test-region", "test-cluster"))
val job = submitter.buildFlinkJob(
mainClass = "ai.chronon.flink.FlinkJob",
jarUri = "gs://zipline-jars/cloud-gcp.jar",
jarUris = Array("gs://zipline-jars/cloud-gcp.jar"),
mainJarUri = "gs://zipline-jars/flink-assembly-0.1.0-SNAPSHOT.jar",
flinkCheckpointUri = "gs://zl-warehouse/flink-state",
maybeSavePointUri = None,
maybePubSubConnectorJarUri = None,
jobProperties = Map("key" -> "value"),
args = List("args1", "args2"): _*
)
Expand All @@ -115,11 +113,10 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
conf = SubmitterConf("test-project", "test-region", "test-cluster"))
val job = submitter.buildFlinkJob(
mainClass = "ai.chronon.flink.FlinkJob",
jarUri = "gs://zipline-jars/cloud-gcp.jar",
jarUris = Array("gs://zipline-jars/cloud-gcp.jar", "gs://zipline-jars/flink-pubsub-connector.jar"),
mainJarUri = "gs://zipline-jars/flink-assembly-0.1.0-SNAPSHOT.jar",
flinkCheckpointUri = "gs://zl-warehouse/flink-state",
maybeSavePointUri = None,
maybePubSubConnectorJarUri = Some("gs://zipline-jars/flink-pubsub-connector.jar"),
jobProperties = Map("key" -> "value"),
args = List("args1", "args2"): _*
)
Expand Down Expand Up @@ -300,6 +297,53 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {

assertEquals(actual(SavepointUri), userPassedSavepoint)
}

it should "test createSubmissionPropsMap for flink job with additional jars" in {
val confPath = "chronon/cloud_gcp/src/test/resources/group_bys/team/purchases.v1"
val runfilesDir = Option(System.getenv("RUNFILES_DIR")).getOrElse(".")
val path = Paths.get(runfilesDir, confPath)

val manifestBucketPath = "gs://zipline-warehouse/flink-manifest"
val groupByName = "quickstart.purchases.v1"
val flinkCheckpointUri = "gs://zl-warehouse/flink-state/checkpoints"
val ziplineVersion = "0.1.0"
val mainClass = "ai.chronon.flink.FlinkJob"
val jarURI = "gs://zipline-jars/cloud-gcp.jar"
val additionalJars = "gs://zipline-jars/some-jar.jar,gs://zipline-jars/another-jar.jar"
val flinkMainJarURI = "gs://zipline-jars/flink-assembly-0.1.0-SNAPSHOT.jar"
val pubSubConnectorJarURI = "gs://zipline-jars/flink-pubsub-connector.jar"
val userPassedSavepoint = "gs://zl-warehouse/flink-state/1234/chk-12"

val submitter = mock[DataprocSubmitter]

val actual = DataprocSubmitter.createSubmissionPropsMap(
jobType = submission.FlinkJob,
submitter = submitter,
args = Array(
s"$JarUriArgKeyword=$jarURI",
s"$AdditionalJarsUriArgKeyword=$additionalJars",
s"$MainClassKeyword=$mainClass",
s"$LocalConfPathArgKeyword=${path.toAbsolutePath.toString}",
s"$ConfTypeArgKeyword=group_bys",
s"$OriginalModeArgKeyword=streaming",
s"$ZiplineVersionArgKeyword=$ziplineVersion",
s"$FlinkMainJarUriArgKeyword=$flinkMainJarURI",
s"$FlinkPubSubJarUriArgKeyword=$pubSubConnectorJarURI",
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCustomSavepointArgKeyword=$userPassedSavepoint",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$JobIdArgKeyword=job-id"
)
)

assertEquals(actual(MainClass), mainClass)
assertEquals(actual(JarURI), jarURI)
assertEquals(actual(FlinkMainJarURI), flinkMainJarURI)
assertEquals(actual(AdditionalJars), additionalJars)
assertEquals(actual(FlinkPubSubConnectorJarURI), pubSubConnectorJarURI)
}

it should "test getDataprocFilesArgs when empty" in {
val actual = DataprocSubmitter.getDataprocFilesArgs()
assert(actual.isEmpty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ object JobSubmitterConstants {
val JarURI = "jarUri"
val FlinkMainJarURI = "flinkMainJarUri"
val FlinkPubSubConnectorJarURI = "flinkPubSubConnectorJarUri"
val AdditionalJars = "additionalJars"
val SavepointUri = "savepointUri"
val FlinkStateUri = "flinkStateUri"
val FlinkCheckpointUri = "flinkCheckpointUri"
Expand Down Expand Up @@ -168,6 +169,7 @@ object JobSubmitterConstants {
val MainClassKeyword = "--main-class"
val FlinkMainJarUriArgKeyword = "--flink-main-jar-uri"
val FlinkPubSubJarUriArgKeyword = "--flink-pubsub-jar-uri"
val AdditionalJarsUriArgKeyword = "--additional-jars"
val FilesArgKeyword = "--files"
val ConfTypeArgKeyword = "--conf-type"
val LocalConfPathArgKeyword = "--local-conf-path"
Expand All @@ -193,6 +195,7 @@ object JobSubmitterConstants {
MainClassKeyword,
FlinkMainJarUriArgKeyword,
FlinkPubSubJarUriArgKeyword,
AdditionalJarsUriArgKeyword,
LocalConfPathArgKeyword,
OriginalModeArgKeyword,
FilesArgKeyword,
Expand Down