-
Notifications
You must be signed in to change notification settings - Fork 8
Refactor run.py to not hardcode against zipline-artifacts or zipline-warehouse buckets but configurable #658
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
68b15ec
867ab6e
e516323
164f6bf
345a5c7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,12 +12,15 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||
| ROUTES, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| SPARK_MODES, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| UNIVERSAL_ROUTES, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| RunMode, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| class Runner: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, args, jar_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, args): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.jar_path = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.repo = args["repo"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.conf = args["conf"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.local_abs_conf_path = os.path.realpath(os.path.join(self.repo, self.conf)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -81,7 +84,6 @@ def __init__(self, args, jar_path): | |||||||||||||||||||||||||||||||||||||||||||||||||
| if "parallelism" in args and args["parallelism"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else 1 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.jar_path = jar_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| self.args = args["args"] if args["args"] else "" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.app_name = args["app_name"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -96,6 +98,15 @@ def __init__(self, args, jar_path): | |||||||||||||||||||||||||||||||||||||||||||||||||
| self.spark_submit = args["spark_submit_path"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.list_apps_cmd = args["list_apps"] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| self.zipline_artifacts_bucket = (args.get("zipline_artifacts_bucket") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| or os.environ.get(ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.zipline_warehouse_bucket = (args.get("zipline_warehouse_bucket") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| or os.environ.get(ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+101
to
+105
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add bucket name validation. Validate that the bucket names have correct prefixes (S3_PREFIX or GCS_PREFIX). self.zipline_artifacts_bucket = (args.get("zipline_artifacts_bucket")
or os.environ.get(ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY))
self.zipline_warehouse_bucket = (args.get("zipline_warehouse_bucket")
or os.environ.get(ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY))
+
+if self.zipline_artifacts_bucket and not (self.zipline_artifacts_bucket.startswith(S3_PREFIX) or
+ self.zipline_artifacts_bucket.startswith(GCS_PREFIX)):
+ raise ValueError(f"Artifacts bucket must start with {S3_PREFIX} or {GCS_PREFIX}")
+
+if self.zipline_warehouse_bucket and not (self.zipline_warehouse_bucket.startswith(S3_PREFIX) or
+ self.zipline_warehouse_bucket.startswith(GCS_PREFIX)):
+ raise ValueError(f"Warehouse bucket must start with {S3_PREFIX} or {GCS_PREFIX}")📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def set_jar_path(self, jar_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.jar_path = jar_path | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def run_spark_streaming(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # streaming mode | ||||||||||||||||||||||||||||||||||||||||||||||||||
| self.app_name = self.app_name.replace( | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,14 +8,13 @@ | |
| from google.cloud import storage | ||
|
|
||
| from ai.chronon.logger import get_logger | ||
| from ai.chronon.repo.constants import ROUTES, ZIPLINE_DIRECTORY | ||
| from ai.chronon.repo.constants import GCS_PREFIX, ROUTES, ZIPLINE_DIRECTORY | ||
| from ai.chronon.repo.default_runner import Runner | ||
| from ai.chronon.repo.utils import ( | ||
| JobType, | ||
| check_call, | ||
| check_output, | ||
| extract_filename_from_path, | ||
| get_customer_id, | ||
| get_environ_arg, | ||
| retry_decorator, | ||
| split_date_range, | ||
|
|
@@ -33,26 +32,36 @@ | |
|
|
||
| class GcpRunner(Runner): | ||
| def __init__(self, args): | ||
| self._args = args | ||
| super().__init__(args) | ||
|
|
||
| # Validate bucket names start with "gs://" | ||
| for bucket in [self.zipline_artifacts_bucket, self.zipline_warehouse_bucket]: | ||
| if not bucket.startswith(GCS_PREFIX): | ||
| raise ValueError( | ||
| f"Invalid bucket name: {bucket}. " | ||
| f"Bucket names must start with '{GCS_PREFIX}'." | ||
| ) | ||
|
Comment on lines
+38
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Same |
||
|
|
||
| gcp_jar_path = GcpRunner.download_zipline_dataproc_jar( | ||
| ZIPLINE_DIRECTORY, | ||
| get_customer_id(), | ||
| args["version"], | ||
| ZIPLINE_GCP_JAR_DEFAULT, | ||
| destination_dir=ZIPLINE_DIRECTORY, | ||
| version=args["version"], | ||
| jar_name=ZIPLINE_GCP_JAR_DEFAULT, | ||
| bucket_name=self.zipline_artifacts_bucket | ||
| ) | ||
| service_jar_path = GcpRunner.download_zipline_dataproc_jar( | ||
| ZIPLINE_DIRECTORY, | ||
| get_customer_id(), | ||
| args["version"], | ||
| ZIPLINE_GCP_SERVICE_JAR, | ||
| destination_dir=ZIPLINE_DIRECTORY, | ||
| version=args["version"], | ||
| jar_name=ZIPLINE_GCP_SERVICE_JAR, | ||
| bucket_name=self.zipline_artifacts_bucket | ||
| ) | ||
| jar_path = ( | ||
| f"{service_jar_path}:{gcp_jar_path}" | ||
| if args["mode"] == "fetch" | ||
| else gcp_jar_path | ||
| ) | ||
| self.set_jar_path(os.path.expanduser(jar_path)) | ||
|
|
||
| self._args = args | ||
| super().__init__(args, os.path.expanduser(jar_path)) | ||
|
|
||
| @staticmethod | ||
| def get_gcp_project_id() -> str: | ||
|
|
@@ -92,6 +101,9 @@ def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name): | |
|
|
||
| try: | ||
| storage_client = storage.Client(project=GcpRunner.get_gcp_project_id()) | ||
| # Remove the "gs://" prefix if it exists as this API doesn't support the bucket having the prefix | ||
| if bucket_name.startswith(GCS_PREFIX): | ||
| bucket_name = bucket_name[len(GCS_PREFIX) :] | ||
| bucket = storage_client.bucket(bucket_name) | ||
| blob = bucket.blob(destination_blob_name) | ||
| blob.upload_from_filename(source_file_name) | ||
|
|
@@ -109,6 +121,11 @@ def get_gcs_file_hash(bucket_name: str, blob_name: str) -> str: | |
| Get the hash of a file stored in Google Cloud Storage. | ||
| """ | ||
| storage_client = storage.Client(project=GcpRunner.get_gcp_project_id()) | ||
|
|
||
| # Remove the "gs://" prefix if it exists as this API doesn't support the bucket having the prefix | ||
| if bucket_name.startswith(GCS_PREFIX): | ||
| bucket_name = bucket_name[len(GCS_PREFIX):] | ||
|
|
||
| bucket = storage_client.bucket(bucket_name) | ||
| blob = bucket.get_blob(blob_name) | ||
|
|
||
|
|
@@ -171,9 +188,8 @@ def compare_gcs_and_local_file_hashes( | |
|
|
||
| @staticmethod | ||
| def download_zipline_dataproc_jar( | ||
| destination_dir: str, customer_id: str, version: str, jar_name: str | ||
| destination_dir: str, version: str, jar_name: str, bucket_name: str | ||
| ): | ||
| bucket_name = f"zipline-artifacts-{customer_id}" | ||
|
|
||
| source_blob_name = f"release/{version}/jars/{jar_name}" | ||
| destination_path = f"{destination_dir}/{jar_name}" | ||
|
|
@@ -204,7 +220,6 @@ def generate_dataproc_submitter_args( | |
| job_type: JobType = JobType.SPARK, | ||
| local_files_to_upload: List[str] = None, | ||
| ): | ||
| customer_warehouse_bucket_name = f"zipline-warehouse-{get_customer_id()}" | ||
|
|
||
| if local_files_to_upload is None: | ||
| local_files_to_upload = [] | ||
|
|
@@ -217,20 +232,17 @@ def generate_dataproc_submitter_args( | |
| ) | ||
| gcs_files.append( | ||
| GcpRunner.upload_gcs_blob( | ||
| customer_warehouse_bucket_name, source_file, destination_file_path | ||
| self.zipline_warehouse_bucket, source_file, destination_file_path | ||
| ) | ||
| ) | ||
|
|
||
| # we also want the additional-confs included here. it should already be in the bucket | ||
|
|
||
| zipline_artifacts_bucket_prefix = "gs://zipline-artifacts" | ||
|
|
||
| gcs_file_args = ",".join(gcs_files) | ||
|
|
||
| # include jar uri. should also already be in the bucket | ||
| jar_uri = ( | ||
| f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}" | ||
| + f"/release/{version}/jars/{ZIPLINE_GCP_JAR_DEFAULT}" | ||
| f"{self.zipline_artifacts_bucket}/release/{version}/jars/{ZIPLINE_GCP_JAR_DEFAULT}" | ||
| ) | ||
|
|
||
| final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class}" | ||
|
|
@@ -240,8 +252,7 @@ def generate_dataproc_submitter_args( | |
| if job_type == JobType.FLINK: | ||
| main_class = "ai.chronon.flink.FlinkJob" | ||
| flink_jar_uri = ( | ||
| f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}" | ||
| + f"/release/{version}/jars/{ZIPLINE_GCP_FLINK_JAR_DEFAULT}" | ||
| f"{self.zipline_artifacts_bucket}/release/{version}/jars/{ZIPLINE_GCP_FLINK_JAR_DEFAULT}" | ||
| ) | ||
| return ( | ||
| final_args.format( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Guard
Noneand strips3://before AWS SDK callsbucketmay beNone, and boto3 expects bare bucket names (nos3://).bucket.startswith()onNoneraises, and passing the prefixed value to boto3 will fail.📝 Committable suggestion