37 changes: 22 additions & 15 deletions api/python/ai/chronon/repo/aws.py
@@ -6,13 +6,12 @@
import boto3

from ai.chronon.logger import get_logger
from ai.chronon.repo.constants import ROUTES, ZIPLINE_DIRECTORY
from ai.chronon.repo.constants import ROUTES, S3_PREFIX, ZIPLINE_DIRECTORY
from ai.chronon.repo.default_runner import Runner
from ai.chronon.repo.utils import (
JobType,
check_call,
extract_filename_from_path,
get_customer_id,
split_date_range,
)

@@ -32,18 +31,31 @@

class AwsRunner(Runner):
def __init__(self, args):
self._args = args
super().__init__(args)

# Validate bucket names start with "s3://"
for bucket in [self.zipline_artifacts_bucket, self.zipline_warehouse_bucket]:
if not bucket.startswith(S3_PREFIX):
raise ValueError(
f"Invalid bucket name: {bucket}. "
f"Bucket names must start with '{S3_PREFIX}'."
)

Comment on lines +38 to +44 (Contributor):
⚠️ Potential issue

Guard None and strip s3:// before AWS SDK calls
bucket may be None, and boto3 expects bare bucket names (no s3:// prefix). Calling bucket.startswith() on None raises an AttributeError, and passing the prefixed value to boto3 will fail.

-        for bucket in [self.zipline_artifacts_bucket, self.zipline_warehouse_bucket]:
-            if not bucket.startswith(S3_PREFIX):
+        for bucket in [self.zipline_artifacts_bucket, self.zipline_warehouse_bucket]:
+            if bucket is None:
+                raise ValueError("Zipline bucket not provided.")
+            if not bucket.startswith(S3_PREFIX):
                 raise ValueError(
                     f"Invalid bucket name: {bucket}. "
                     f"Bucket names must start with '{S3_PREFIX}'."
                 )
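The suggestion above only adds the None guard; the comment title also asks for stripping s3:// before SDK calls. A minimal sketch of that second part, assuming a helper name (strip_s3_prefix) that is not part of this PR:

from typing import Optional

S3_PREFIX = "s3://"

def strip_s3_prefix(bucket: Optional[str]) -> str:
    """Return the bare bucket name that boto3 expects (hypothetical helper)."""
    if bucket is None:
        raise ValueError("Zipline bucket not provided.")
    # boto3 client calls take bare bucket names, so drop any s3:// scheme.
    return bucket[len(S3_PREFIX):] if bucket.startswith(S3_PREFIX) else bucket

# e.g. strip_s3_prefix("s3://zipline-artifacts-dev") -> "zipline-artifacts-dev"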

aws_jar_path = AwsRunner.download_zipline_aws_jar(
ZIPLINE_DIRECTORY, get_customer_id(), args["version"], ZIPLINE_AWS_JAR_DEFAULT
destination_dir=ZIPLINE_DIRECTORY, version=args["version"], jar_name=ZIPLINE_AWS_JAR_DEFAULT,
bucket_name=self.zipline_artifacts_bucket
)
service_jar_path = AwsRunner.download_zipline_aws_jar(
ZIPLINE_DIRECTORY, get_customer_id(), args["version"], ZIPLINE_AWS_SERVICE_JAR
destination_dir=ZIPLINE_DIRECTORY, version=args["version"], jar_name=ZIPLINE_AWS_SERVICE_JAR,
bucket_name=self.zipline_artifacts_bucket
)
jar_path = (
f"{service_jar_path}:{aws_jar_path}" if args['mode'] == "fetch" else aws_jar_path
)
self.version = args.get("version", "latest")
self.set_jar_path(os.path.expanduser(jar_path))

super().__init__(args, os.path.expanduser(jar_path))

@staticmethod
def upload_s3_file(
@@ -61,11 +73,10 @@ def upload_s3_file(
raise RuntimeError(f"Failed to upload {source_file_name}: {str(e)}") from e

@staticmethod
def download_zipline_aws_jar(destination_dir: str, customer_id: str, version: str, jar_name: str):
def download_zipline_aws_jar(destination_dir: str, version: str, jar_name: str, bucket_name: str):
s3_client = boto3.client("s3")
destination_path = f"{destination_dir}/{jar_name}"
source_key_name = f"release/{version}/jars/{jar_name}"
bucket_name = f"zipline-artifacts-{customer_id}"

are_identical = (
AwsRunner.compare_s3_and_local_file_hashes(
@@ -140,7 +151,7 @@ def generate_emr_submitter_args(
job_type: JobType = JobType.SPARK,
local_files_to_upload: List[str] = None,
):
customer_warehouse_bucket_name = f"zipline-warehouse-{get_customer_id()}"
customer_warehouse_bucket_name = self.zipline_warehouse_bucket
s3_files = []
for source_file in local_files_to_upload:
# upload to `metadata` folder
@@ -155,27 +166,23 @@

# we also want the additional-confs included here. it should already be in the bucket

zipline_artifacts_bucket_prefix = "s3://zipline-artifacts"

s3_files.append(
f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}/confs/additional-confs.yaml"
f"{self.zipline_artifacts_bucket}/confs/additional-confs.yaml"
)

s3_file_args = ",".join(s3_files)

# include jar uri. should also already be in the bucket
jar_uri = (
f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}"
+ f"/release/{self.version}/jars/{ZIPLINE_AWS_JAR_DEFAULT}"
f"{self.zipline_artifacts_bucket}/release/{self.version}/jars/{ZIPLINE_AWS_JAR_DEFAULT}"
)

final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class}"

if job_type == JobType.FLINK:
main_class = "ai.chronon.flink.FlinkJob"
flink_jar_uri = (
f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}"
+ f"/jars/{ZIPLINE_AWS_FLINK_JAR_DEFAULT}"
f"{self.zipline_artifacts_bucket}/jars/{ZIPLINE_AWS_FLINK_JAR_DEFAULT}"
)
return (
final_args.format(
6 changes: 6 additions & 0 deletions api/python/ai/chronon/repo/constants.py
@@ -156,3 +156,9 @@ def __str__(self):
# arg keywords
ONLINE_CLASS_ARG = "online_class"
ONLINE_JAR_ARG = "online_jar"

ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY = "ZIPLINE_ARTIFACTS_BUCKET"
ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY = "ZIPLINE_WAREHOUSE_BUCKET"

GCS_PREFIX = "gs://"
S3_PREFIX = "s3://"
15 changes: 13 additions & 2 deletions api/python/ai/chronon/repo/default_runner.py
@@ -12,12 +12,15 @@
ROUTES,
SPARK_MODES,
UNIVERSAL_ROUTES,
ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY,
ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY,
RunMode,
)


class Runner:
def __init__(self, args, jar_path):
def __init__(self, args):
self.jar_path = None
self.repo = args["repo"]
self.conf = args["conf"]
self.local_abs_conf_path = os.path.realpath(os.path.join(self.repo, self.conf))
@@ -81,7 +84,6 @@ def __init__(self, args, jar_path):
if "parallelism" in args and args["parallelism"]
else 1
)
self.jar_path = jar_path

self.args = args["args"] if args["args"] else ""
self.app_name = args["app_name"]
@@ -96,6 +98,15 @@
self.spark_submit = args["spark_submit_path"]
self.list_apps_cmd = args["list_apps"]

self.zipline_artifacts_bucket = (args.get("zipline_artifacts_bucket")
or os.environ.get(ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY))
self.zipline_warehouse_bucket = (args.get("zipline_warehouse_bucket")
or os.environ.get(ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY))

Comment on lines +101 to +105 (Contributor):
🛠️ Refactor suggestion

Add bucket name validation.

Validate that the bucket names have correct prefixes (S3_PREFIX or GCS_PREFIX).

 self.zipline_artifacts_bucket = (args.get("zipline_artifacts_bucket")
                                  or os.environ.get(ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY))
 self.zipline_warehouse_bucket = (args.get("zipline_warehouse_bucket")
                                  or os.environ.get(ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY))
+
+if self.zipline_artifacts_bucket and not (self.zipline_artifacts_bucket.startswith(S3_PREFIX) or 
+                                          self.zipline_artifacts_bucket.startswith(GCS_PREFIX)):
+    raise ValueError(f"Artifacts bucket must start with {S3_PREFIX} or {GCS_PREFIX}")
+
+if self.zipline_warehouse_bucket and not (self.zipline_warehouse_bucket.startswith(S3_PREFIX) or 
+                                          self.zipline_warehouse_bucket.startswith(GCS_PREFIX)):
+    raise ValueError(f"Warehouse bucket must start with {S3_PREFIX} or {GCS_PREFIX}")
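For illustration only, a tiny sketch of how the proposed base-class check would behave; the values below are made up, and the buckets stay optional (None skips the check):

S3_PREFIX = "s3://"
GCS_PREFIX = "gs://"

def check_bucket(name, label):
    """Mirrors the suggested validation above (sketch, not part of the PR)."""
    if name and not (name.startswith(S3_PREFIX) or name.startswith(GCS_PREFIX)):
        raise ValueError(f"{label} bucket must start with {S3_PREFIX} or {GCS_PREFIX}")

check_bucket("s3://zipline-artifacts-dev", "Artifacts")  # ok
check_bucket(None, "Warehouse")                          # ok: unset, so no check
# check_bucket("zipline-artifacts-dev", "Artifacts")     # would raise ValueError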

def set_jar_path(self, jar_path):
self.jar_path = jar_path


def run_spark_streaming(self):
# streaming mode
self.app_name = self.app_name.replace(
55 changes: 33 additions & 22 deletions api/python/ai/chronon/repo/gcp.py
@@ -8,14 +8,13 @@
from google.cloud import storage

from ai.chronon.logger import get_logger
from ai.chronon.repo.constants import ROUTES, ZIPLINE_DIRECTORY
from ai.chronon.repo.constants import GCS_PREFIX, ROUTES, ZIPLINE_DIRECTORY
from ai.chronon.repo.default_runner import Runner
from ai.chronon.repo.utils import (
JobType,
check_call,
check_output,
extract_filename_from_path,
get_customer_id,
get_environ_arg,
retry_decorator,
split_date_range,
@@ -33,26 +32,36 @@

class GcpRunner(Runner):
def __init__(self, args):
self._args = args
super().__init__(args)

# Validate bucket names start with "gs://"
for bucket in [self.zipline_artifacts_bucket, self.zipline_warehouse_bucket]:
if not bucket.startswith(GCS_PREFIX):
raise ValueError(
f"Invalid bucket name: {bucket}. "
f"Bucket names must start with '{GCS_PREFIX}'."
)
Comment on lines +38 to +44 (Contributor):
🛠️ Refactor suggestion

Same None / prefix issue as AWS
Handle missing buckets before .startswith() to avoid AttributeError.
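A minimal sketch of that guard, mirroring the AWS suggestion above (the validate_gcs_bucket helper name is illustrative, not part of this PR):

from typing import Optional

GCS_PREFIX = "gs://"

def validate_gcs_bucket(bucket: Optional[str]) -> str:
    """Fail fast on a missing or unprefixed bucket name (sketch)."""
    if bucket is None:
        raise ValueError("Zipline bucket not provided.")
    if not bucket.startswith(GCS_PREFIX):
        raise ValueError(
            f"Invalid bucket name: {bucket}. "
            f"Bucket names must start with '{GCS_PREFIX}'."
        )
    return bucket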


gcp_jar_path = GcpRunner.download_zipline_dataproc_jar(
ZIPLINE_DIRECTORY,
get_customer_id(),
args["version"],
ZIPLINE_GCP_JAR_DEFAULT,
destination_dir=ZIPLINE_DIRECTORY,
version=args["version"],
jar_name=ZIPLINE_GCP_JAR_DEFAULT,
bucket_name=self.zipline_artifacts_bucket
)
service_jar_path = GcpRunner.download_zipline_dataproc_jar(
ZIPLINE_DIRECTORY,
get_customer_id(),
args["version"],
ZIPLINE_GCP_SERVICE_JAR,
destination_dir=ZIPLINE_DIRECTORY,
version=args["version"],
jar_name=ZIPLINE_GCP_SERVICE_JAR,
bucket_name=self.zipline_artifacts_bucket
)
jar_path = (
f"{service_jar_path}:{gcp_jar_path}"
if args["mode"] == "fetch"
else gcp_jar_path
)
self.set_jar_path(os.path.expanduser(jar_path))

self._args = args
super().__init__(args, os.path.expanduser(jar_path))

@staticmethod
def get_gcp_project_id() -> str:
@@ -92,6 +101,9 @@ def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name):

try:
storage_client = storage.Client(project=GcpRunner.get_gcp_project_id())
# Remove the "gs://" prefix if it exists as this API doesn't support the bucket having the prefix
if bucket_name.startswith(GCS_PREFIX):
bucket_name = bucket_name[len(GCS_PREFIX) :]
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(destination_blob_name)
blob.upload_from_filename(source_file_name)
@@ -109,6 +121,11 @@ def get_gcs_file_hash(bucket_name: str, blob_name: str) -> str:
Get the hash of a file stored in Google Cloud Storage.
"""
storage_client = storage.Client(project=GcpRunner.get_gcp_project_id())

# Remove the "gs://" prefix if it exists as this API doesn't support the bucket having the prefix
if bucket_name.startswith(GCS_PREFIX):
bucket_name = bucket_name[len(GCS_PREFIX):]

bucket = storage_client.bucket(bucket_name)
blob = bucket.get_blob(blob_name)

@@ -171,9 +188,8 @@ def compare_gcs_and_local_file_hashes(

@staticmethod
def download_zipline_dataproc_jar(
destination_dir: str, customer_id: str, version: str, jar_name: str
destination_dir: str, version: str, jar_name: str, bucket_name: str
):
bucket_name = f"zipline-artifacts-{customer_id}"

source_blob_name = f"release/{version}/jars/{jar_name}"
destination_path = f"{destination_dir}/{jar_name}"
@@ -204,7 +220,6 @@ def generate_dataproc_submitter_args(
job_type: JobType = JobType.SPARK,
local_files_to_upload: List[str] = None,
):
customer_warehouse_bucket_name = f"zipline-warehouse-{get_customer_id()}"

if local_files_to_upload is None:
local_files_to_upload = []
@@ -217,20 +232,17 @@
)
gcs_files.append(
GcpRunner.upload_gcs_blob(
customer_warehouse_bucket_name, source_file, destination_file_path
self.zipline_warehouse_bucket, source_file, destination_file_path
)
)

# we also want the additional-confs included here. it should already be in the bucket

zipline_artifacts_bucket_prefix = "gs://zipline-artifacts"

gcs_file_args = ",".join(gcs_files)

# include jar uri. should also already be in the bucket
jar_uri = (
f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}"
+ f"/release/{version}/jars/{ZIPLINE_GCP_JAR_DEFAULT}"
f"{self.zipline_artifacts_bucket}/release/{version}/jars/{ZIPLINE_GCP_JAR_DEFAULT}"
)

final_args = "{user_args} --jar-uri={jar_uri} --job-type={job_type} --main-class={main_class}"
@@ -240,8 +252,7 @@
if job_type == JobType.FLINK:
main_class = "ai.chronon.flink.FlinkJob"
flink_jar_uri = (
f"{zipline_artifacts_bucket_prefix}-{get_customer_id()}"
+ f"/release/{version}/jars/{ZIPLINE_GCP_FLINK_JAR_DEFAULT}"
f"{self.zipline_artifacts_bucket}/release/{version}/jars/{ZIPLINE_GCP_FLINK_JAR_DEFAULT}"
)
return (
final_args.format(
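Since the gs:// prefix is now stripped in both upload_gcs_blob and get_gcs_file_hash, a small shared helper could consolidate that logic. A sketch under that assumption (the helper name is hypothetical):

GCS_PREFIX = "gs://"

def bare_gcs_bucket(bucket_name: str) -> str:
    """Drop the gs:// scheme; google.cloud.storage clients expect bare bucket names."""
    if bucket_name.startswith(GCS_PREFIX):
        return bucket_name[len(GCS_PREFIX):]
    return bucket_name

# e.g. bare_gcs_bucket("gs://zipline-warehouse-dev") -> "zipline-warehouse-dev"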
20 changes: 16 additions & 4 deletions api/python/ai/chronon/repo/run.py
@@ -31,14 +31,17 @@
)
from ai.chronon.repo.constants import (
APP_NAME_TEMPLATE,
AWS,
CLOUD_PROVIDER_KEYWORD,
GCP,
MODE_ARGS,
ONLINE_CLASS_ARG,
ONLINE_JAR_ARG,
ONLINE_MODES,
RENDER_INFO_DEFAULT_SCRIPT,
ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY,
ZIPLINE_DIRECTORY,
ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY,
RunMode,
)
from ai.chronon.repo.default_runner import Runner
@@ -82,7 +85,6 @@ def set_defaults(ctx):
if ctx.params.get(key) is None and value is not None:
ctx.params[key] = value


def _set_package_version():
try:
package_version = ver("zipline-ai")
@@ -177,9 +179,15 @@ def _set_package_version():
help="Validate the catalyst util Spark expression evaluation logic",
)
@click.option(
"--validate-rows", default="10000", help="Number of rows to run the validation on"
"--validate-rows", default="10000", help="Number of rows to run the validation on"
)
@click.option("--join-part-name", help="Name of the join part to use for join-part-job")
@click.option("--zipline-artifacts-bucket",
help=f"Bucket containing Zipline artifacts. "
f"Can also set via environment variable {ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY}")
@click.option("--zipline-warehouse-bucket",
help=f"Bucket containing Zipline warehouse data. "
f"Can also set via environment variable {ZIPLINE_WAREHOUSE_BUCKET_ENV_KEY}")
@click.pass_context
def main(
ctx,
@@ -214,6 +222,8 @@ def main(
validate,
validate_rows,
join_part_name,
zipline_artifacts_bucket,
zipline_warehouse_bucket,
):
unknown_args = ctx.args
click.echo("Running with args: {}".format(ctx.params))
@@ -230,15 +240,17 @@
if not cloud_provider:
# Support open source chronon runs
if chronon_jar:
Runner(ctx.params, os.path.expanduser(chronon_jar)).run()
default_runner = Runner(ctx.params)
default_runner.jar_path = os.path.expanduser(chronon_jar)
default_runner.run()
else:
raise ValueError("Jar path is not set.")
elif cloud_provider.upper() == GCP:
ctx.params[ONLINE_JAR_ARG] = ZIPLINE_GCP_JAR_DEFAULT
ctx.params[ONLINE_CLASS_ARG] = ZIPLINE_GCP_ONLINE_CLASS_DEFAULT
ctx.params[CLOUD_PROVIDER_KEYWORD] = cloud_provider
GcpRunner(ctx.params).run()
elif cloud_provider.upper() == "AWS":
elif cloud_provider.upper() == AWS:
ctx.params[ONLINE_JAR_ARG] = ZIPLINE_AWS_JAR_DEFAULT
ctx.params[ONLINE_CLASS_ARG] = ZIPLINE_AWS_ONLINE_CLASS_DEFAULT
ctx.params[CLOUD_PROVIDER_KEYWORD] = cloud_provider
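The new flags mirror the environment variables added in constants.py, and in the Runner the flag value takes precedence over the variable when both are set. A small sketch of that resolution order, using illustrative values rather than anything from this PR:

import os

ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY = "ZIPLINE_ARTIFACTS_BUCKET"

def resolve_bucket(cli_value=None):
    """CLI flag wins; the environment variable is the fallback (sketch)."""
    return cli_value or os.environ.get(ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY)

os.environ[ZIPLINE_ARTIFACTS_BUCKET_ENV_KEY] = "gs://zipline-artifacts-dev"
assert resolve_bucket() == "gs://zipline-artifacts-dev"             # env var used
assert resolve_bucket("s3://other-bucket") == "s3://other-bucket"   # flag overrides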
6 changes: 5 additions & 1 deletion api/python/test/canary/teams.py
@@ -57,6 +57,8 @@
"GCP_REGION": "us-central1",
"GCP_DATAPROC_CLUSTER_NAME": "zipline-canary-cluster",
"GCP_BIGTABLE_INSTANCE_ID": "zipline-canary-instance",
"ZIPLINE_ARTIFACTS_BUCKET": "gs://zipline-artifacts-dev",
"ZIPLINE_WAREHOUSE_BUCKET": "gs://zipline-warehouse-dev",
},
),
conf=ConfigProperties(
@@ -92,7 +94,7 @@
RunMode.BACKFILL: {
"spark.chronon.backfill_cloud_provider": "gcp", # dummy test config
}
}
},
),
)

@@ -102,6 +104,8 @@
common={
"CLOUD_PROVIDER": "aws",
"CUSTOMER_ID": "dev",
"ZIPLINE_ARTIFACTS_BUCKET": "s3://zipline-artifacts-dev",
"ZIPLINE_WAREHOUSE_BUCKET": "s3://zipline-warehouse-dev",
}
),
)