Merged
Commits (changes shown from 18 of 29 commits)
f941a6f  Connect run.py to DataprocSubmitter.scala so that offline jobs can be… (david-zlai, Jan 8, 2025)
a6b154b  put this back. (david-zlai, Jan 8, 2025)
e3b0c48  fix lint stuff. (david-zlai, Jan 8, 2025)
7472f45  add some more notes on why we only take the filename. (david-zlai, Jan 8, 2025)
61eba99  rename (david-zlai, Jan 8, 2025)
a8525bc  fix python linting. (david-zlai, Jan 8, 2025)
67d2159  Adding google-cloud-storage to dev.in and running pip-compile-multi … (david-zlai, Jan 8, 2025)
49b60a7  pin zipp (david-zlai, Jan 9, 2025)
ad3207c  try removing certifi (david-zlai, Jan 9, 2025)
f9ef45f  upgrade zipp to 3.21.0 even higher. (david-zlai, Jan 9, 2025)
957afad  remove zipp. (david-zlai, Jan 9, 2025)
8d0257a  try again with snyk (david-zlai, Jan 9, 2025)
0096de2  pin also in base. (david-zlai, Jan 9, 2025)
d7c82c3  downgrade google cloud core (david-zlai, Jan 9, 2025)
f3da7d7  downgrade google-cloud-core. (david-zlai, Jan 9, 2025)
f96d062  one more attempt (david-zlai, Jan 9, 2025)
3d1d6ff  PR comments. (david-zlai, Jan 9, 2025)
fc40e7d  more changes. (david-zlai, Jan 9, 2025)
01ba185  cleanup (david-zlai, Jan 9, 2025)
96a748a  fix delete (david-zlai, Jan 9, 2025)
263899c  format (david-zlai, Jan 9, 2025)
fe13f57  more PR comments. (david-zlai, Jan 9, 2025)
e5282e0  fix deps (david-zlai, Jan 9, 2025)
a1463cd  python formatting (david-zlai, Jan 9, 2025)
047c21e  actually fix formatting (david-zlai, Jan 9, 2025)
903c23f  fix project id (david-zlai, Jan 9, 2025)
431b7e1  constant around dataproc entry (david-zlai, Jan 9, 2025)
20141ea  download dataproc jar once only (david-zlai, Jan 9, 2025)
14ec11b  remove print (david-zlai, Jan 9, 2025)
197 changes: 170 additions & 27 deletions api/py/ai/chronon/repo/run.py
@@ -17,6 +17,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from google.cloud import storage
import argparse
import json
import logging
Expand All @@ -25,6 +26,7 @@
import re
import subprocess
import time
from typing import List
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta

@@ -193,11 +195,11 @@ def download_only_once(url, path, skip_download=False):

@retry_decorator(retries=3, backoff=50)
def download_jar(
version,
jar_type="uber",
release_tag=None,
spark_version="2.4.0",
skip_download=False,
version,
jar_type="uber",
release_tag=None,
spark_version="2.4.0",
Contributor: 3.5.1 right?

Contributor Author: uhm, not sure. it's been working so far with what we've been doing... i didn't make any changes here, just indented.

Contributor (@nikhil-zlai, Jan 9, 2025): for now we can keep it as is - we will edit versions later.

skip_download=False,
):
assert (
spark_version in SUPPORTED_SPARK
@@ -523,31 +525,90 @@ def run(self):
self.start_ds, self.ds, self.parallelism
)
for start_ds, end_ds in date_ranges:
if not args.dataproc:
command = (
"bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} "
+ "{additional_args}"
).format(
script=self.spark_submit,
jar=self.jar_path,
subcommand=ROUTES[self.conf_type][self.mode],
args=self._gen_final_args(start_ds=start_ds, end_ds=end_ds),
additional_args=os.environ.get(
"CHRONON_CONFIG_ADDITIONAL_ARGS", ""
),
)
command_list.append(command)
else:
# we'll always download the jar for now so that we can pull in any fixes or latest changes
dataproc_jar = download_dataproc_jar(get_customer_id())
Collaborator: python has libraries to handle cleanup and everything. You can simplify this code by:

    import tempfile

    with tempfile.TemporaryDirectory() as temp_dir:
        jar_path = download_dataproc_jar(temp_dir, get_customer_id())
        ...

No explicit cleanup needed once you exit the contextmanager!
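A fuller sketch of that suggestion, purely illustrative: it assumes download_dataproc_jar is reworked to accept a destination directory (the version in this PR only takes customer_id), so the jar lives inside the temporary directory and is removed automatically when the block exits.

    import os
    import tempfile

    def download_dataproc_jar(destination_dir: str, customer_id: str) -> str:
        # hypothetical variant of the helper later in this file: caller controls where the jar lands
        destination = os.path.join(destination_dir, "cloud_gcp_submitter-assembly-0.1.0-SNAPSHOT.jar")
        bucket_name = f"zipline-artifacts-{customer_id}"
        source_blob_name = "jars/cloud_gcp_submitter-assembly-0.1.0-SNAPSHOT.jar"
        download_gcs_blob(bucket_name, source_blob_name, destination)
        return destination

    with tempfile.TemporaryDirectory() as temp_dir:
        # both the jar and temp_dir are deleted when this block exits
        jar_path = download_dataproc_jar(temp_dir, get_customer_id())
        # ... build and run the `java -cp {jar_path} ...` command here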

dataproc_entry = "ai.chronon.integrations.cloud_gcp.DataprocSubmitter"

user_args = (
"{subcommand} {args} {additional_args}"
).format(
subcommand=ROUTES[self.conf_type][self.mode],
args=self._gen_final_args(start_ds=self.start_ds,
end_ds=end_ds,
# overriding the conf here because we only want the filename,
# not the full path. When we upload this to GCS, the full path
# does get reflected on GCS. But when we include the gcs file
# path as part of dataproc, the file is copied to root and
# not the complete path is copied.
override_conf_path=self.conf.split("/")[-1]),
additional_args=os.environ.get(
"CHRONON_CONFIG_ADDITIONAL_ARGS", ""
),
)

dataproc_command = generate_dataproc_submitter_args(
local_files_to_upload_to_gcs=[self.conf],
# for now, self.conf is the only local file that requires uploading to gcs
user_args=user_args
)

command = f"java -cp {dataproc_jar} {dataproc_entry} {dataproc_command}"
command_list.append(command)
else:
if not args.dataproc:
command = (
"bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
).format(
script=self.spark_submit,
jar=self.jar_path,
subcommand=ROUTES[self.conf_type][self.mode],
args=self._gen_final_args(start_ds=start_ds, end_ds=end_ds),
args=self._gen_final_args(self.start_ds),
additional_args=os.environ.get(
"CHRONON_CONFIG_ADDITIONAL_ARGS", ""
),
)
command_list.append(command)
else:
command = (
"bash {script} --class ai.chronon.spark.Driver {jar} {subcommand} {args} {additional_args}"
).format(
script=self.spark_submit,
jar=self.jar_path,
subcommand=ROUTES[self.conf_type][self.mode],
args=self._gen_final_args(self.start_ds),
additional_args=os.environ.get(
"CHRONON_CONFIG_ADDITIONAL_ARGS", ""
),
)
command_list.append(command)
else:
dataproc_jar = download_dataproc_jar(get_customer_id())
dataproc_entry = "ai.chronon.integrations.cloud_gcp.DataprocSubmitter"
user_args = (
"{subcommand} {args} {additional_args}"
).format(
subcommand=ROUTES[self.conf_type][self.mode],
args=self._gen_final_args(start_ds=self.start_ds,
# overriding the conf here because we only want the filename,
# not the full path. When we upload this to GCS, the full path
# does get reflected on GCS. But when we include the gcs file
# path as part of dataproc, the file is copied to root and
# not the complete path is copied.
override_conf_path=self.conf.split("/")[-1]),
additional_args=os.environ.get(
"CHRONON_CONFIG_ADDITIONAL_ARGS", ""
),
)

dataproc_command = generate_dataproc_submitter_args(
# for now, self.conf is the only local file that requires uploading to gcs
local_files_to_upload_to_gcs=[self.conf],
user_args=user_args
)
command = f"java -cp {dataproc_jar} {dataproc_entry} {dataproc_command}"
command_list.append(command)
if len(command_list) > 1:
# parallel backfill mode
with multiprocessing.Pool(processes=int(self.parallelism)) as pool:
@@ -560,9 +621,9 @@ def run(self):
elif len(command_list) == 1:
check_call(command_list[0])

def _gen_final_args(self, start_ds=None, end_ds=None):
def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None):
base_args = MODE_ARGS[self.mode].format(
conf_path=self.conf,
conf_path=override_conf_path if override_conf_path else self.conf,
ds=end_ds if end_ds else self.ds,
online_jar=self.online_jar,
online_class=self.online_class,
@@ -628,6 +689,87 @@ def set_defaults(parser):
)


def get_customer_id() -> str:
customer_id = os.environ.get('CUSTOMER_ID')
if not customer_id:
raise ValueError('Please set CUSTOMER_ID environment variable')
return customer_id
Contributor: 🛠️ Refactor suggestion: add validation for required GCP environment variables, and validate them all upfront.

+def validate_gcp_env():
+    required_vars = [
+        'GCP_PROJECT',
+        'GCP_REGION',
+        'GCP_DATAPROC_CLUSTER_NAME',
+        'CUSTOMER_ID'
+    ]
+    missing = [var for var in required_vars if not os.environ.get(var)]
+    if missing:
+        raise ValueError(f'Missing required environment variables: {", ".join(missing)}')

 def get_customer_id() -> str:
+    validate_gcp_env()
     customer_id = os.environ.get('CUSTOMER_ID')
-    if not customer_id:
-        raise ValueError('Please set CUSTOMER_ID environment variable')
     return customer_id



def generate_dataproc_submitter_args(local_files_to_upload_to_gcs: List[str], user_args: str):
# TODO: change this when gcs://zipline-warehouse-etsy is available
bucket_name = "zipline-jars"
Contributor: 🛠️ Refactor suggestion: avoid hardcoding bucket names; make bucket_name configurable for flexibility.
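One possible shape of that change, sketched here only; the ZIPLINE_JAR_BUCKET variable name is made up for illustration and is not something run.py defines in this PR:

    def get_jar_bucket_name() -> str:
        # hypothetical env override, falling back to the current hardcoded default
        return os.environ.get("ZIPLINE_JAR_BUCKET", "zipline-jars")

    def generate_dataproc_submitter_args(local_files_to_upload_to_gcs: List[str], user_args: str):
        bucket_name = get_jar_bucket_name()
        ...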


gcs_files = []
for f in local_files_to_upload_to_gcs:
gcs_files.append(upload_gcs_blob(bucket_name, f, f))

# we also want the additional-confs included here
gcs_files.append(f"gs://zipline-artifacts-{get_customer_id()}/confs/additional-confs.yaml")

gcs_file_args = ",".join(gcs_files)

# include chronon jar uri
chronon_jar_uri = f"gs://zipline-artifacts-{get_customer_id()}/jars/cloud_gcp-assembly-0.1.0-SNAPSHOT.jar"

final_args = (f"{user_args} --additional-conf-path=additional-confs.yaml --gcs_files={gcs_file_args} "
f"--chronon_jar_uri={chronon_jar_uri}")

return final_args


def download_dataproc_jar(customer_id: str):
destination_file_name = "/tmp/cloud_gcp_submitter-assembly-0.1.0-SNAPSHOT.jar"
Contributor: I wonder if we should write this + other details like project id etc. to a well-known zipline dir on the user's machine (can punt to a follow-up PR).

Collaborator: I do think we should clean this jar up though after running it. @david-zlai would that be easy to do?

Contributor Author: added a cleanup step at the end.
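A minimal sketch of what that cleanup could look like; the actual cleanup step isn't shown in this hunk, and the try/finally placement here is an assumption:

    # sketch: remove the downloaded submitter jar once the command has run
    dataproc_jar = download_dataproc_jar(get_customer_id())
    try:
        check_call(f"java -cp {dataproc_jar} {dataproc_entry} {dataproc_command}")
    finally:
        if os.path.exists(dataproc_jar):
            os.remove(dataproc_jar)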

print("Downloading dataproc submitter jar from GCS...")
bucket_name = f"zipline-artifacts-{customer_id}"
source_blob_name = "jars/cloud_gcp_submitter-assembly-0.1.0-SNAPSHOT.jar"
download_gcs_blob(bucket_name, source_blob_name, destination_file_name)
return destination_file_name


def download_gcs_blob(bucket_name, source_blob_name, destination_file_name):
"""Downloads a blob from the bucket."""
storage_client = storage.Client()

bucket = storage_client.bucket(bucket_name)

# Construct a client side representation of a blob.
# Note `Bucket.blob` differs from `Bucket.get_blob` as it doesn't retrieve
# any content from Google Cloud Storage. As we don't need additional data,
# using `Bucket.blob` is preferred here.
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)

print(
"Downloaded storage object {} from bucket {} to local file {}.".format(
source_blob_name, bucket_name, destination_file_name
)
)
Contributor: 🛠️ Refactor suggestion: add a retry mechanism and proper error handling for GCS operations. GCS operations should be resilient to transient failures.

+@retry_decorator(retries=3, backoff=20)
 def download_gcs_blob(bucket_name, source_blob_name, destination_file_name):
     """Downloads a blob from the bucket."""
-    storage_client = storage.Client()
+    try:
+        storage_client = storage.Client()
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(source_blob_name)
+        blob.download_to_filename(destination_file_name)
+    except Exception as e:
+        raise RuntimeError(f"Failed to download {source_blob_name}: {str(e)}")

+@retry_decorator(retries=3, backoff=20)
 def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name):
     """Uploads a file to the bucket."""
-    storage_client = storage.Client()
+    try:
+        storage_client = storage.Client()
+        bucket = storage_client.bucket(bucket_name)
+        blob = bucket.blob(destination_blob_name)
+        blob.upload_from_filename(source_file_name)
+    except Exception as e:
+        raise RuntimeError(f"Failed to upload {source_file_name}: {str(e)}")

Also applies to: 763-783



def upload_gcs_blob(bucket_name, source_file_name, destination_blob_name):
"""Uploads a file to the bucket."""

storage_client = storage.Client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(destination_blob_name)

# Optional: set a generation-match precondition to avoid potential race conditions
# and data corruptions. The request to upload is aborted if the object's
# generation number does not match your precondition. For a destination
# object that does not yet exist, set the if_generation_match precondition to 0.
# If the destination object already exists in your bucket, set instead a
# generation-match precondition using its generation number.
# generation_match_precondition = 0

blob.upload_from_filename(source_file_name)

print(
f"File {source_file_name} uploaded to {destination_blob_name} in bucket {bucket_name}."
)
return f"gs://{bucket_name}/{destination_blob_name}"


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Submit various kinds of chronon jobs")
parser.add_argument(
@@ -642,27 +784,28 @@ def set_defaults(parser):
help="Running environment - default to be dev",
)
parser.add_argument("--mode", choices=MODE_ARGS.keys())
parser.add_argument("--dataproc", action="store_true", help="Run on dataproc")
parser.add_argument("--ds", help="the end partition to backfill the data")
parser.add_argument(
"--app-name", help="app name. Default to {}".format(APP_NAME_TEMPLATE)
)
parser.add_argument(
"--start-ds",
help="override the original start partition for a range backfill. "
"It only supports staging query, group by backfill and join jobs. "
"It could leave holes in your final output table due to the override date range.",
"It only supports staging query, group by backfill and join jobs. "
"It could leave holes in your final output table due to the override date range.",
)
parser.add_argument("--end-ds", help="the end ds for a range backfill")
parser.add_argument(
"--parallelism",
help="break down the backfill range into this number of tasks in parallel. "
"Please use it along with --start-ds and --end-ds and only in manual mode",
"Please use it along with --start-ds and --end-ds and only in manual mode",
)
parser.add_argument("--repo", help="Path to chronon repo")
parser.add_argument(
"--online-jar",
help="Jar containing Online KvStore & Deserializer Impl. "
+ "Used for streaming and metadata-upload mode.",
+ "Used for streaming and metadata-upload mode.",
)
parser.add_argument(
"--online-class",
@@ -679,7 +822,7 @@ def set_defaults(parser):
parser.add_argument(
"--online-jar-fetch",
help="Path to script that can pull online jar. "
+ "This will run only when a file doesn't exist at location specified by online_jar",
+ "This will run only when a file doesn't exist at location specified by online_jar",
)
parser.add_argument(
"--sub-help",
@@ -704,7 +847,7 @@ def set_defaults(parser):
parser.add_argument(
"--render-info",
help="Path to script rendering additional information of the given config. "
+ "Only applicable when mode is set to info",
+ "Only applicable when mode is set to info",
)
set_defaults(parser)
pre_parse_args, _ = parser.parse_known_args()
8 changes: 4 additions & 4 deletions api/py/requirements/base.txt
@@ -1,13 +1,13 @@
# SHA1:1d44bb5a0f927ef885e838e299990ba7ecd68dda
# SHA1:3079b5e84710de85b877bd986aa6167385f428c8
#
# This file is autogenerated by pip-compile-multi
# To update, run:
#
# pip-compile-multi
#
click==8.1.7
click==8.1.8
# via -r requirements/base.in
six==1.16.0
six==1.17.0
# via thrift
thrift==0.20.0
thrift==0.13.0
Contributor: our thrift needs to be at 0.21 - or compile will break i think

Contributor Author: let me set this in base.in then

# via -r requirements/base.in
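If that pin does move to base.in as discussed in the thread above, the entry would look something like this (a sketch; 0.21.0 is assumed from the comment, and the exact constraint may differ):

    # api/py/requirements/base.in
    thrift==0.21.0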
4 changes: 4 additions & 0 deletions api/py/requirements/dev.in
@@ -5,3 +5,7 @@ black
pre-commit
isort
autoflake
zipp==3.19.1
importlib-metadata==8.4.0
google-cloud-storage==2.19.0
