
Commit 0207e63

Merge branch 'main' into vz--partition_column_in_source_v2
2 parents: b447d8f + a493d35, commit 0207e63

File tree: 6 files changed, +56 -33 lines

  api/py/ai/chronon/repo/run.py
  api/py/requirements/base.in
  api/py/requirements/base.txt
  cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala
  distribution/run_zipline_quickstart.sh
  spark/src/main/scala/ai/chronon/spark/JoinBase.scala

api/py/ai/chronon/repo/run.py

Lines changed: 42 additions & 19 deletions
@@ -20,7 +20,7 @@
 from google.cloud import storage
 import base64
 import click
-import google_crc32c
+import crcmod
 import json
 import logging
 import multiprocessing
@@ -707,7 +707,30 @@ def run(self):
                 )
                 pool.map(check_call, command_list)
             elif len(command_list) == 1:
-                check_call(command_list[0])
+                if self.dataproc:
+                    output = check_output(command_list[0]).decode("utf-8").split("\n")
+                    print(*output, sep="\n")
+
+                    dataproc_submitter_id_str = "Dataproc submitter job id"
+
+                    dataproc_submitter_logs = [s for s in output if dataproc_submitter_id_str in s]
+                    if dataproc_submitter_logs:
+                        log = dataproc_submitter_logs[0]
+                        job_id = log[log.index(dataproc_submitter_id_str) + len(dataproc_submitter_id_str) + 1:]
+                        try:
+                            print("""
+                            <-----------------------------------------------------------------------------------
+                            ------------------------------------------------------------------------------------
+                                                              DATAPROC LOGS
+                            ------------------------------------------------------------------------------------
+                            ------------------------------------------------------------------------------------>
+                            """)
+                            check_call(f"gcloud dataproc jobs wait {job_id} --region={get_gcp_region_id()}")
+                        except Exception:
+                            # swallow since this is just for tailing logs
+                            pass
+                else:
+                    check_call(command_list[0])
 
     def _gen_final_args(self, start_ds=None, end_ds=None, override_conf_path=None, **kwargs):
         base_args = MODE_ARGS[self.mode].format(
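A small, self-contained sketch of the log-scan introduced above, for readers skimming the hunk: the output lines and job id are fabricated, and the split-based parse is a simplified restatement of the index arithmetic in the diff (the real output comes from check_output(command_list[0])).

    # Fabricated submitter output for illustration only.
    marker = "Dataproc submitter job id"
    output = [
        "Uploading local files to GCS...",
        "Dataproc submitter job id: 8cf3-example",
        "Safe to exit. Follow the job status at: https://console.cloud.google.com/dataproc/jobs/8cf3-example",
    ]
    matches = [line for line in output if marker in line]
    if matches:
        # Take everything after "...job id: " as the id to wait on.
        job_id = matches[0].split(": ", 1)[1]   # -> "8cf3-example"
        # run.py then tails the job via:
        #   gcloud dataproc jobs wait <job_id> --region=<GCP_REGION>
        print(job_id)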
@@ -794,27 +817,27 @@ def set_defaults(ctx):
             ctx.params[key] = value
 
 
+def get_environ_arg(env_name) -> str:
+    value = os.environ.get(env_name)
+    if not value:
+        raise ValueError(f"Please set {env_name} environment variable")
+    return value
+
+
 def get_customer_id() -> str:
-    customer_id = os.environ.get('CUSTOMER_ID')
-    if not customer_id:
-        raise ValueError('Please set CUSTOMER_ID environment variable')
-    return customer_id
+    return get_environ_arg('CUSTOMER_ID')
 
 
 def get_gcp_project_id() -> str:
-    gcp_project_id = os.environ.get('GCP_PROJECT_ID')
-    if not gcp_project_id:
-        raise ValueError(
-            'Please set GCP_PROJECT_ID environment variable')
-    return gcp_project_id
+    return get_environ_arg('GCP_PROJECT_ID')
 
 
 def get_gcp_bigtable_instance_id() -> str:
-    gcp_bigtable_instance_id = os.environ.get('GCP_BIGTABLE_INSTANCE_ID')
-    if not gcp_bigtable_instance_id:
-        raise ValueError(
-            'Please set GCP_BIGTABLE_INSTANCE_ID environment variable')
-    return gcp_bigtable_instance_id
+    return get_environ_arg('GCP_BIGTABLE_INSTANCE_ID')
+
+
+def get_gcp_region_id() -> str:
+    return get_environ_arg('GCP_REGION')
 
 
 def generate_dataproc_submitter_args(user_args: str, job_type: DataprocJobType = DataprocJobType.SPARK,
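A quick, hypothetical usage sketch of the consolidated helper: the import path mirrors api/py/ai/chronon/repo/run.py and the region value is illustrative, not taken from any real environment.

    import os
    # Assumed import path, mirroring api/py/ai/chronon/repo/run.py:
    from ai.chronon.repo.run import get_gcp_region_id

    os.environ["GCP_REGION"] = "us-central1"   # illustrative value
    print(get_gcp_region_id())                 # -> "us-central1"

    del os.environ["GCP_REGION"]
    # With the variable unset, every accessor now fails the same way:
    # ValueError: Please set GCP_REGION environment variable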
@@ -947,15 +970,15 @@ def get_local_file_hash(file_path: str) -> str:
     Returns:
         Base64-encoded string of the file's CRC32C hash
     """
-    crc32c = google_crc32c.Checksum()
+    crc32c_hash = crcmod.predefined.Crc('crc-32c')
 
     with open(file_path, "rb") as f:
         # Read the file in chunks to handle large files efficiently
         for chunk in iter(lambda: f.read(4096), b""):
-            crc32c.update(chunk)
+            crc32c_hash.update(chunk)
 
     # Convert to base64 to match GCS format
-    return base64.b64encode(crc32c.digest()).decode('utf-8')
+    return base64.b64encode(crc32c_hash.digest()).decode('utf-8')
 
 
 def compare_gcs_and_local_file_hashes(bucket_name: str, blob_name: str, local_file_path: str) -> bool:
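As a sanity check on the dependency swap, a minimal sketch (assuming crcmod==1.7 from base.in is installed) of the same digest-then-base64 step the helper performs; the input bytes are arbitrary.

    import base64
    import crcmod

    crc32c_hash = crcmod.predefined.Crc('crc-32c')
    crc32c_hash.update(b"example bytes")   # arbitrary payload for illustration
    print(base64.b64encode(crc32c_hash.digest()).decode('utf-8'))
    # The intent is that this string matches the base64 CRC32C value GCS
    # reports for the same bytes, so compare_gcs_and_local_file_hashes keeps working.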

api/py/requirements/base.in

Lines changed: 2 additions & 1 deletion
@@ -3,4 +3,5 @@ thrift==0.21.0
 google-cloud-storage==2.19.0
 google-cloud-bigquery-storage
 pyspark==3.5.4
-sqlglot
+sqlglot
+crcmod==1.7

api/py/requirements/base.txt

Lines changed: 4 additions & 1 deletion
@@ -1,16 +1,19 @@
-# SHA1:fe8b0a1dc101ff0b0ffa9a959160d459b7f7d0e3
+# SHA1:68652bbb7f3ec5c449c5d85307085c0c94bc4da3
 #
 # This file is autogenerated by pip-compile-multi
 # To update, run:
 #
 #    pip-compile-multi
 #
+
 cachetools==5.5.0
     # via google-auth
 charset-normalizer==3.4.1
     # via requests
 click==8.1.8
     # via -r requirements/base.in
+crcmod==1.7
+    # via -r requirements/base.in
 google-api-core==2.24.0
     # via
     #   google-cloud-core

cloud_gcp/src/main/scala/ai/chronon/integrations/cloud_gcp/DataprocSubmitter.scala

Lines changed: 1 addition & 0 deletions
@@ -290,6 +290,7 @@ object DataprocSubmitter {
       finalArgs: _*
     )
     println("Dataproc submitter job id: " + jobId)
+    println(s"Safe to exit. Follow the job status at: https://console.cloud.google.com/dataproc/jobs/${jobId}")
   }
 }

distribution/run_zipline_quickstart.sh

Lines changed: 6 additions & 6 deletions
@@ -7,6 +7,7 @@ set -xo pipefail
 WORKING_DIR=$1
 cd $WORKING_DIR
 
+
 GREEN='\033[0;32m'
 RED='\033[0;31m'

@@ -45,7 +46,6 @@ function check_dataproc_job_state() {
     echo "No job id available to check. Exiting."
     exit 1
   fi
-  gcloud dataproc jobs wait $JOB_ID --region=us-central1
   echo -e "${GREEN} <<<<<<<<<<<<<<<<-----------------JOB STATUS----------------->>>>>>>>>>>>>>>>>\033[0m"
   JOB_STATE=$(gcloud dataproc jobs describe $JOB_ID --region=us-central1 --format=flattened | grep "status.state:")
  echo $JOB_STATE
@@ -62,33 +62,33 @@ zipline compile --conf=group_bys/quickstart/purchases.py
 
 echo -e "${GREEN}<<<<<.....................................BACKFILL.....................................>>>>>\033[0m"
 touch tmp_backfill.out
-zipline run --conf production/group_bys/quickstart/purchases.v1_test --dataproc 2>&1 | tee /dev/tty tmp_backfill.out
+zipline run --conf production/group_bys/quickstart/purchases.v1_test --dataproc 2>&1 | tee tmp_backfill.out
 BACKFILL_JOB_ID=$(cat tmp_backfill.out | grep "$DATAPROC_SUBMITTER_ID_STR" | cut -d " " -f5)
 check_dataproc_job_state $BACKFILL_JOB_ID
 
 echo -e "${GREEN}<<<<<.....................................GROUP-BY-UPLOAD.....................................>>>>>\033[0m"
 touch tmp_gbu.out
-zipline run --mode upload --conf production/group_bys/quickstart/purchases.v1_test --ds 2023-12-01 --dataproc 2>&1 | tee /dev/tty tmp_gbu.out
+zipline run --mode upload --conf production/group_bys/quickstart/purchases.v1_test --ds 2023-12-01 --dataproc 2>&1 | tee tmp_gbu.out
 GBU_JOB_ID=$(cat tmp_gbu.out | grep "$DATAPROC_SUBMITTER_ID_STR" | cut -d " " -f5)
 check_dataproc_job_state $GBU_JOB_ID
 
 # Need to wait for upload to finish
 echo -e "${GREEN}<<<<<.....................................UPLOAD-TO-KV.....................................>>>>>\033[0m"
 touch tmp_upload_to_kv.out
-zipline run --mode upload-to-kv --conf production/group_bys/quickstart/purchases.v1_test --partition-string=2023-12-01 --dataproc 2>&1 | tee /dev/tty tmp_upload_to_kv.out
+zipline run --mode upload-to-kv --conf production/group_bys/quickstart/purchases.v1_test --partition-string=2023-12-01 --dataproc 2>&1 | tee tmp_upload_to_kv.out
 UPLOAD_TO_KV_JOB_ID=$(cat tmp_upload_to_kv.out | grep "$DATAPROC_SUBMITTER_ID_STR" | cut -d " " -f5)
 check_dataproc_job_state $UPLOAD_TO_KV_JOB_ID
 
 echo -e "${GREEN}<<<<< .....................................METADATA-UPLOAD.....................................>>>>>\033[0m"
 touch tmp_metadata_upload.out
-zipline run --mode metadata-upload --conf production/group_bys/quickstart/purchases.v1_test --dataproc 2>&1 | tee /dev/tty tmp_metadata_upload.out
+zipline run --mode metadata-upload --conf production/group_bys/quickstart/purchases.v1_test --dataproc 2>&1 | tee tmp_metadata_upload.out
 METADATA_UPLOAD_JOB_ID=$(cat tmp_metadata_upload.out | grep "$DATAPROC_SUBMITTER_ID_STR" | cut -d " " -f5)
 check_dataproc_job_state $METADATA_UPLOAD_JOB_ID
 
 # Need to wait for upload-to-kv to finish
 echo -e "${GREEN}<<<<<.....................................FETCH.....................................>>>>>\033[0m"
 touch tmp_fetch.out
-zipline run --mode fetch --conf-type group_bys --name quickstart/purchases.v1_test -k '{"user_id":"5"}' 2>&1 | tee /dev/tty tmp_fetch.out | grep -q purchase_price_average_14d
+zipline run --mode fetch --conf-type group_bys --name quickstart/purchases.v1_test -k '{"user_id":"5"}' 2>&1 | tee tmp_fetch.out | grep -q purchase_price_average_14d
 cat tmp_fetch.out | grep purchase_price_average_14d
 # check if exit code of previous is 0
 if [ $? -ne 0 ]; then
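The quickstart script recovers each job id from the submitter line by word position (grep for "$DATAPROC_SUBMITTER_ID_STR", then cut -d " " -f5). A rough Python restatement, with a fabricated log line, shows why field 5 is the id:

    # "Dataproc submitter job id: <id>" splits into five space-separated fields;
    # the fifth is the id, matching `cut -d " " -f5` in the script.
    line = "Dataproc submitter job id: 8cf3-example"   # fabricated id
    fields = line.split(" ")
    print(fields[4])   # -> "8cf3-example"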

spark/src/main/scala/ai/chronon/spark/JoinBase.scala

Lines changed: 1 addition & 6 deletions
@@ -284,7 +284,7 @@ abstract class JoinBase(val joinConfCloned: api.Join,
         Map(user -> user_name, user_name -> user)
       the below logic will first rename the conflicted column with some random suffix and update the rename map
      */
-    lazy val renamedLeftRawDf = {
+    lazy val renamedLeftDf = {
       val columns = skewFilteredLeft.columns.flatMap { column =>
         if (joinPart.leftToRight.contains(column)) {
           Some(col(column).as(joinPart.leftToRight(column)))
@@ -299,11 +299,6 @@ abstract class JoinBase(val joinConfCloned: api.Join,
 
     lazy val shiftedPartitionRange = unfilledTimeRange.toPartitionRange.shift(-1)
 
-    val renamedLeftDf = renamedLeftRawDf.select(renamedLeftRawDf.columns.map {
-      case c if c == tableUtils.partitionColumn =>
-        date_format(renamedLeftRawDf.col(c), tableUtils.partitionFormat).as(c)
-      case c => renamedLeftRawDf.col(c)
-    }.toList: _*)
     val rightDf = (joinConfCloned.left.dataModel, joinPart.groupBy.dataModel, joinPart.groupBy.inferredAccuracy) match {
       case (Entities, Events, _) => partitionRangeGroupBy.snapshotEvents(unfilledRange)
       case (Entities, Entities, _) => partitionRangeGroupBy.snapshotEntities
