Merged
api/python/ai/chronon/repo/gcp.py (18 additions, 22 deletions)
@@ -3,7 +3,7 @@
import multiprocessing
import os
import time
from typing import List
import uuid
from urllib.parse import urlparse

import crcmod
@@ -61,6 +61,7 @@ def __init__(self, args):
)

self._args = args
self.job_id = str(uuid.uuid4())

super().__init__(args, os.path.expanduser(jar_path))

@@ -242,35 +243,35 @@ def generate_dataproc_submitter_args(
version: str,
customer_artifact_prefix: str,
job_type: JobType = JobType.SPARK,
local_files_to_upload: List[str] = None,
metadata_conf_path: str = None,
):

parsed = urlparse(customer_artifact_prefix)
source_blob_name = parsed.path.lstrip("/")

if local_files_to_upload is None:
local_files_to_upload = []

gcs_files = []
for source_file in local_files_to_upload:
# upload to `metadata` folder

# upload to `metadata` folder
if metadata_conf_path:
destination_file_path = os.path.join(
source_blob_name,
"metadata",
f"{extract_filename_from_path(source_file)}"
self.job_id,
f"{extract_filename_from_path(metadata_conf_path)}"
)
gcs_files.append(
GcpRunner.upload_gcs_blob(
get_customer_warehouse_bucket(), source_file, destination_file_path
get_customer_warehouse_bucket(), metadata_conf_path, destination_file_path
)
)

gcs_file_args = ",".join(gcs_files)
release_prefix = os.path.join(customer_artifact_prefix, "release", version, "jars")

# include jar uri. should also already be in the bucket
jar_uri = os.path.join(release_prefix, f"{ZIPLINE_GCP_JAR_DEFAULT}")

final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class} --zipline-version={zipline_version}"
final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class} --zipline-version={zipline_version} --job-id={job_id}"

if job_type == JobType.FLINK:
main_class = "ai.chronon.flink.FlinkJob"
@@ -282,6 +283,7 @@
job_type=job_type.value,
main_class=main_class,
zipline_version=self._version,
job_id=self.job_id,
)
+ f" --flink-main-jar-uri={flink_jar_uri}"
)
@@ -294,6 +296,7 @@
job_type=job_type.value,
main_class=main_class,
zipline_version=self._version,
job_id=self.job_id,
) + (f" --files={gcs_file_args}" if gcs_file_args else "")
else:
raise ValueError(f"Invalid job type: {job_type}")
@@ -387,15 +390,12 @@ def run(self):
"--partition-names" in args
), "Must specify a list of `--partition-names=schema.table/pk1=pv1/pk2=pv2"

local_files_to_upload_to_gcs = (
[os.path.join(self.repo, self.conf)] if self.conf else []
)
dataproc_args = self.generate_dataproc_submitter_args(
# for now, self.conf is the only local file that requires uploading to gcs
local_files_to_upload=local_files_to_upload_to_gcs,
user_args=self._gen_final_args(),
version=self._version,
customer_artifact_prefix=self._remote_artifact_prefix,
metadata_conf_path=str(os.path.join(self.repo, self.conf)) if self.conf else None
)
command = f"java -cp {self.jar_path} {DATAPROC_ENTRY} {dataproc_args}"
command_list.append(command)
@@ -405,9 +405,6 @@
command = self.run_dataproc_flink_streaming()
command_list.append(command)
else:
local_files_to_upload_to_gcs = (
[os.path.join(self.repo, self.conf)] if self.conf else []
)
if self.parallelism > 1:
assert self.start_ds is not None and self.ds is not None, (
"To use parallelism, please specify --start-ds and --end-ds to "
@@ -438,11 +435,11 @@
)

dataproc_args = self.generate_dataproc_submitter_args(
local_files_to_upload=local_files_to_upload_to_gcs,
# for now, self.conf is the only local file that requires uploading to gcs
user_args=user_args,
version=self._version,
customer_artifact_prefix=self._remote_artifact_prefix
customer_artifact_prefix=self._remote_artifact_prefix,
metadata_conf_path=str(os.path.join(self.repo, self.conf)) if self.conf else None,
)
command = (
f"java -cp {self.jar_path} {DATAPROC_ENTRY} {dataproc_args}"
@@ -467,11 +464,10 @@
),
)
dataproc_args = self.generate_dataproc_submitter_args(
# for now, self.conf is the only local file that requires uploading to gcs
local_files_to_upload=local_files_to_upload_to_gcs,
user_args=user_args,
version=self._version,
customer_artifact_prefix=self._remote_artifact_prefix
customer_artifact_prefix=self._remote_artifact_prefix,
metadata_conf_path=str(os.path.join(self.repo, self.conf)) if self.conf else None
)
command = f"java -cp {self.jar_path} {DATAPROC_ENTRY} {dataproc_args}"
command_list.append(command)
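Net effect of the gcp.py changes: each `GcpRunner` mints one UUID at construction time, scopes the uploaded metadata conf's destination path by that id, and forwards the same id to the submitter as `--job-id`. A minimal sketch of that plumbing, assuming hypothetical bucket, conf, and jar names (the repo's helper functions are elided):

```python
import os
import uuid

# One id per runner instance, reused for every submission (mirrors self.job_id in gcp.py).
job_id = str(uuid.uuid4())

# The uploaded conf's destination is scoped by the job id, so concurrent runs
# of the same conf no longer clobber each other's metadata uploads.
metadata_conf_path = "/repo/production/group_bys/example.v1"  # hypothetical path
destination = os.path.join("zipline/artifacts", job_id,
                           os.path.basename(metadata_conf_path))

# The same id is appended to the submitter args; the Scala DataprocSubmitter
# (next file) now fails fast if the flag is missing.
final_args = "--jar-uri={jar} --job-type={jt} --job-id={jid}".format(
    jar="gs://bucket/release/0.1.0/jars/zipline-gcp.jar",  # hypothetical URI
    jt="spark",
    jid=job_id,
)
print(destination)
print(final_args)
```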
DataprocSubmitter.scala
@@ -7,7 +7,6 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.yaml.snakeyaml.Yaml
import java.util.UUID

import scala.io.Source
import scala.jdk.CollectionConverters._
@@ -376,7 +375,9 @@ object DataprocSubmitter {

val metadataName = Option(JobSubmitter.getMetadata(args).get.getName).getOrElse("")

val jobId = UUID.randomUUID().toString
val jobId = JobSubmitter
  .getArgValue(args, JobIdArgKeyword)
  .getOrElse(throw new Exception("Missing required argument: " + JobIdArgKeyword))

Collaborator: probably just make this required then, and fail if it's not passed.

val submissionProps = jobType match {
case TypeSparkJob =>
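The `.getOrElse(throw ...)` resolves the review comment: instead of minting a fallback UUID on the JVM side, the submitter treats `--job-id` as required, so the id Dataproc runs under always matches the one the Python launcher generated and used for the metadata upload path. A rough Python analogue of that fail-fast lookup, where `get_arg_value` is a toy stand-in rather than the repo's helper:

```python
def get_arg_value(args, keyword):
    # Stand-in for JobSubmitter.getArgValue: return the value of "--key=value", else None.
    prefix = f"{keyword}="
    return next((a[len(prefix):] for a in args if a.startswith(prefix)), None)

def require_job_id(args):
    # Mirrors the Scala change: fail fast instead of silently generating an id.
    job_id = get_arg_value(args, "--job-id")
    if job_id is None:
        raise ValueError("Missing required argument: --job-id")
    return job_id

# Raises unless --job-id was threaded through from the Python launcher.
print(require_job_id(["--job-type=spark", "--job-id=demo-job-id"]))
```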
DataprocSubmitterTest.scala
@@ -113,7 +113,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$LocalConfPathArgKeyword=${path.toAbsolutePath.toString}",
s"$ConfTypeArgKeyword=group_bys",
s"$OriginalModeArgKeyword=backfill",
s"$ZiplineVersionArgKeyword=0.1.0"
s"$ZiplineVersionArgKeyword=0.1.0",
s"$JobIdArgKeyword=job-id"
)
)

@@ -160,7 +161,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$StreamingLatestSavepointArgKeyword"
s"$StreamingLatestSavepointArgKeyword",
s"$JobIdArgKeyword=job-id"
)
)

@@ -201,7 +203,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$StreamingNoSavepointArgKeyword"
s"$StreamingNoSavepointArgKeyword",
s"$JobIdArgKeyword=job-id"
)
)

@@ -245,7 +248,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCustomSavepointArgKeyword=$userPassedSavepoint",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri"
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$JobIdArgKeyword=job-id"
)
)

@@ -473,7 +477,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCheckpointPathArgKeyword=gs://zl-warehouse/flink-state/checkpoints",
s"$StreamingNoSavepointArgKeyword"
s"$StreamingNoSavepointArgKeyword",
s"$JobIdArgKeyword=job-id"
)
val submitter = mock[DataprocSubmitter]

@@ -526,7 +531,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$StreamingLatestSavepointArgKeyword"
s"$StreamingLatestSavepointArgKeyword",
s"$JobIdArgKeyword=job-id"
)
val submitter = mock[DataprocSubmitter]

@@ -588,7 +594,8 @@ class DataprocSubmitterTest extends AnyFlatSpec with MockitoSugar {
s"$GroupByNameArgKeyword=$groupByName",
s"$StreamingManifestPathArgKeyword=$manifestBucketPath",
s"$StreamingCheckpointPathArgKeyword=$flinkCheckpointUri",
s"$StreamingCustomSavepointArgKeyword=gs://zl-warehouse/flink-state/checkpoints/1234/chk-12"
s"$StreamingCustomSavepointArgKeyword=gs://zl-warehouse/flink-state/checkpoints/1234/chk-12",
s"$JobIdArgKeyword=job-id"
)
val submitter = mock[DataprocSubmitter]

JobSubmitterConstants.scala
@@ -152,6 +152,8 @@ object JobSubmitterConstants {
val StreamingCustomSavepointArgKeyword = "--custom-savepoint"
val StreamingNoSavepointArgKeyword = "--no-savepoint"

val JobIdArgKeyword = "--job-id"

val SharedInternalArgs: Set[String] = Set(
JarUriArgKeyword,
JobTypeArgKeyword,
@@ -167,7 +169,8 @@
StreamingCustomSavepointArgKeyword,
StreamingNoSavepointArgKeyword,
StreamingCheckpointPathArgKeyword,
StreamingVersionCheckDeploy
StreamingVersionCheckDeploy,
JobIdArgKeyword
)

val GcpBigtableInstanceIdEnvVar = "GCP_BIGTABLE_INSTANCE_ID"