From 04e900d6028600444a81b48d200a31dd1d24b9f2 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 11:36:06 -0700 Subject: [PATCH 01/31] Initial cherry pick of CHIP-10 from stripe chronon repo --- api/py/ai/chronon/pyspark/__init__.py | 0 api/py/ai/chronon/pyspark/constants.py | 18 + api/py/ai/chronon/pyspark/databricks.py | 176 ++++ api/py/ai/chronon/pyspark/executables.py | 758 ++++++++++++++++++ api/py/ai/chronon/utils.py | 30 + api/py/requirements/base.in | 1 + api/py/requirements/base.txt | 20 +- .../scala/ai/chronon/spark/PySparkUtils.scala | 209 +++++ .../DatabricksConstantsNameProvider.scala | 12 + .../databricks/DatabricksTableUtils.scala | 117 +++ 10 files changed, 1330 insertions(+), 11 deletions(-) create mode 100644 api/py/ai/chronon/pyspark/__init__.py create mode 100644 api/py/ai/chronon/pyspark/constants.py create mode 100644 api/py/ai/chronon/pyspark/databricks.py create mode 100644 api/py/ai/chronon/pyspark/executables.py create mode 100644 spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala create mode 100644 spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala create mode 100644 spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala diff --git a/api/py/ai/chronon/pyspark/__init__.py b/api/py/ai/chronon/pyspark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py new file mode 100644 index 0000000000..5df155bfec --- /dev/null +++ b/api/py/ai/chronon/pyspark/constants.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +#-------------------------------------------------------------------------- +# Company Specific Constants +#-------------------------------------------------------------------------- + +PARTITION_COLUMN_FORMAT: str = '%Y%m%d' + +#-------------------------------------------------------------------------- +# Platform Specific Constants +#-------------------------------------------------------------------------- + +#-------------------------------------------------------------------------- +# Databricks Constants +#-------------------------------------------------------------------------- +DATABRICKS_OUTPUT_NAMESPACE: str = 'chronon_poc_usertables' +DATABRICKS_JVM_LOG_FILE: str = "/databricks/chronon_logfile.log" +DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" \ No newline at end of file diff --git a/api/py/ai/chronon/pyspark/databricks.py b/api/py/ai/chronon/pyspark/databricks.py new file mode 100644 index 0000000000..c5cad0ef2c --- /dev/null +++ b/api/py/ai/chronon/pyspark/databricks.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import os +from typing_extensions import override +from typing import cast +from pyspark.sql import SparkSession +from pyspark.dbutils import DBUtils +from py4j.java_gateway import JavaObject + +from ai.chronon.pyspark.executables import ( + GroupByExecutable, + JoinExecutable, + PlatformInterface, +) + +from ai.chronon.api.ttypes import GroupBy, Join + +from ai.chronon.pyspark.constants import ( + DATABRICKS_OUTPUT_NAMESPACE, + DATABRICKS_JVM_LOG_FILE, + DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES, +) + +class DatabricksPlatform(PlatformInterface): + """ + Databricks-specific implementation of the platform interface. + """ + + + + def __init__(self, spark: SparkSession): + """ + Initialize Databricks-specific components. 
+ + Args: + spark: The SparkSession to use + """ + super().__init__(spark) + self.dbutils: DBUtils = DBUtils(self.spark) + self.constants_provider: JavaObject = self.get_constants_provider() + self.table_utils: JavaObject = self.get_table_utils() + self.register_udfs() + + @override + def get_constants_provider(self) -> JavaObject: + """ + Get the Databricks constants provider. + + Returns: + A JavaObject representing the constants provider + """ + constants_provider: JavaObject = self.jvm.ai.chronon.spark.databricks.DatabricksConstantsNameProvider() + self.jvm.ai.chronon.api.Constants.initConstantNameProvider(constants_provider) + return constants_provider + + @override + def get_table_utils(self) -> JavaObject: + """ + Get the Databricks table utilities. + + Returns: + A JavaObject representing the table utilities + """ + return self.jvm.ai.chronon.spark.databricks.DatabricksTableUtils(self.java_spark_session) + + @override + def register_udfs(self) -> None: + """Register UDFs for Databricks.""" + pass + + @override + def get_executable_join_cls(self) -> type[JoinExecutable]: + """Get the Databricks-specific join executable class.""" + return DatabricksJoin + + @override + def start_log_capture(self, job_name: str) -> tuple[int, str]: + """ + Start capturing logs in Databricks. + + Args: + job_name: The name of the job for log headers + + Returns: + A tuple of (start_position, job_name) + """ + return (os.path.getsize(DATABRICKS_JVM_LOG_FILE), job_name) + + @override + def end_log_capture(self, capture_token: tuple[int, str]) -> None: + """ + End log capture and print logs in Databricks. + + Args: + capture_token: The token returned by start_log_capture + """ + start_position, job_name = capture_token + + print("\n\n", "*" * 10, f" BEGIN LOGS FOR {job_name} ", "*" * 10) + with open(DATABRICKS_JVM_LOG_FILE, "r") as file_handler: + _ = file_handler.seek(start_position) + print(file_handler.read()) + print("*" * 10, f" END LOGS FOR {job_name} ", "*" * 10, "\n\n") + + + def get_databricks_user(self) -> str: + """ + Get the current Databricks user. + + Returns: + The username of the current Databricks user + """ + user_email = self.dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get() + return user_email.split('@')[0].lower() + + +class DatabricksGroupBy(GroupByExecutable): + """Class for executing GroupBy objects in Databricks.""" + + def __init__(self, group_by: GroupBy, spark_session: SparkSession): + """ + Initialize a GroupBy executor for Databricks. + + Args: + group_by: The GroupBy object to execute + spark_session: The SparkSession to use + """ + super().__init__(group_by, spark_session) + + self.obj: GroupBy = self.platform.set_metadata( + obj=self.obj, + mod_prefix=DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES, + name_prefix=cast(DatabricksPlatform, self.platform).get_databricks_user(), + output_namespace=DATABRICKS_OUTPUT_NAMESPACE + ) + + @override + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface. + + Returns: + The Databricks platform interface + """ + return DatabricksPlatform(self.spark) + + +class DatabricksJoin(JoinExecutable): + """Class for executing Join objects in Databricks.""" + + def __init__(self, join: Join, spark_session: SparkSession): + """ + Initialize a Join executor for Databricks. 
+ + Args: + join: The Join object to execute + spark_session: The SparkSession to use + """ + super().__init__(join, spark_session) + + self.obj: Join = self.platform.set_metadata( + obj=self.obj, + mod_prefix=DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES, + name_prefix=cast(DatabricksPlatform, self.platform).get_databricks_user(), + output_namespace=DATABRICKS_OUTPUT_NAMESPACE + ) + + @override + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface. + + Returns: + The Databricks platform interface + """ + return DatabricksPlatform(self.spark) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py new file mode 100644 index 0000000000..f0469ae893 --- /dev/null +++ b/api/py/ai/chronon/pyspark/executables.py @@ -0,0 +1,758 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +import copy +from datetime import datetime, timedelta +from typing import TypeVar, Generic, cast, Any +from py4j.java_gateway import JavaObject, JVMView + +from pyspark.sql import DataFrame, SparkSession +from py4j.java_gateway import JavaObject + +from ai.chronon.api.ttypes import GroupBy, Join, JoinPart, JoinSource, Source, Query, StagingQuery +from ai.chronon.utils import set_name, get_max_window_for_gb_in_days, output_table_name +from ai.chronon.repo.serializer import thrift_simple_json +from ai.chronon.pyspark.constants import PARTITION_COLUMN_FORMAT + +# Define type variable for our executables +T = TypeVar('T', GroupBy, Join) + + +class PySparkExecutable(Generic[T], ABC): + """ + Abstract base class defining common functionality for executing features via PySpark. + This class provides shared initialization and utility methods but does not + define abstract execution methods, which are left to specialized subclasses. + """ + + def __init__(self, obj: T, spark_session: SparkSession): + """ + Initialize the executable with an object and SparkSession. + + Args: + obj: The object (GroupBy, Join, StagingQuery) to execute + spark_session: The SparkSession to use for execution + """ + self.obj: T = obj + self.spark: SparkSession = spark_session + self.jvm: JVMView = self.spark._jvm + self.platform: PlatformInterface = self.get_platform() + self.java_spark_session: JavaObject = self.spark._jsparkSession + self.default_start_date: str = (datetime.now() - timedelta(days=8)).strftime(PARTITION_COLUMN_FORMAT) + self.default_end_date: str = (datetime.now() - timedelta(days=1)).strftime(PARTITION_COLUMN_FORMAT) + + @abstractmethod + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface for platform-specific operations. + + Returns: + A PlatformInterface instance + """ + pass + + + def _update_query_dates(self, query: Query, start_date: str, end_date: str) -> Query: + """ + Update start and end dates of a query. + + Args: + query: The query to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated query + """ + query_copy = copy.deepcopy(query) + query_copy.startPartition = start_date + query_copy.endPartition = end_date + return query_copy + + def _update_source_dates(self, source: Source, start_date: str, end_date: str) -> Source: + """ + Update start and end dates of a source. 
+ + Args: + source: The source to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated source + """ + source_copy = copy.deepcopy(source) + if source_copy.events and source_copy.events.query: + source_copy.events.query = self._update_query_dates( + cast(Query, source_copy.events.query), start_date, end_date) + elif source_copy.entities and source_copy.entities.query: + source_copy.entities.query = self._update_query_dates( + cast(Query, source_copy.entities.query), start_date, end_date) + return source_copy + + def _execute_underlying_join_sources(self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int) -> None: + """ + Execute underlying join sources. + + Args: + group_bys: List of GroupBy objects + start_date: Start date for execution + end_date: End date for execution + step_days: Number of days to process in each step + """ + + joins_to_execute: list[Join] = [] + join_sources_to_execute_start_dates: dict[str, str] = {} + + for group_by in group_bys: + group_by_join_sources: list[JoinSource] = [ + s.joinSource for s in cast(list[Source], group_by.sources) + if s.joinSource and not s.joinSource.outputTableNameOverride + ] + + if not group_by_join_sources: + continue + + + # Recall that records generated by the inner join are input events for the outer join + # Therefore in order to correctly aggregate the outer join, your inner join needs to be run from start_date - max_window_for_gb_in_days + max_window_for_gb_in_days: int = get_max_window_for_gb_in_days(group_by) + + shifted_start_date = ( + datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - + timedelta(days=max_window_for_gb_in_days) + ).strftime(PARTITION_COLUMN_FORMAT) + + for js in group_by_join_sources: + js_name: str | None = js.join.metaData.name + + if js_name is None: + raise ValueError(f"Join source {js} does not have a name. Was set_metadata called?") + + if js_name not in join_sources_to_execute_start_dates: + join_sources_to_execute_start_dates[js_name] = shifted_start_date + joins_to_execute.append(js.join) + else: + existing_start_date: str = join_sources_to_execute_start_dates[js_name] + join_sources_to_execute_start_dates[js_name] = min(shifted_start_date, existing_start_date) + + if not joins_to_execute: + return + + self.platform.log_operation(f"Executing {len(joins_to_execute)} Join Sources") + + for join in joins_to_execute: + j_start_date: str = join_sources_to_execute_start_dates[join.metaData.name] + + executable_join_cls: type[JoinExecutable] = self.platform.get_executable_join_cls() + + executable_join = executable_join_cls(join, self.spark) + + self.platform.log_operation(f"Executing Join Source {join.metaData.name} from {j_start_date} to {end_date}") + _ = executable_join.run(start_date=j_start_date, end_date=end_date, step_days=step_days) + + output_table_name_for_js: str = output_table_name(join, full_name=True) + self.platform.log_operation(f"Join Source {join.metaData.name} will be read from {output_table_name_for_js}") + + self.platform.log_operation("Finished executing Join Sources") + + def print_with_timestamp(self, message: str) -> None: + """Utility to print a message with timestamp.""" + current_utc_time = datetime.utcnow() + time_str = current_utc_time.strftime('[%Y-%m-%d %H:%M:%S UTC]') + print(f'{time_str} {message}') + + + def group_by_to_java(self, group_by: GroupBy, end_date: str) -> JavaObject: + """ + Convert GroupBy object to Java representation with updated S3 prefixes. 
+ + Args: + group_by: The GroupBy object to convert + end_date: End date for execution + + Returns: + Java representation of the GroupBy + """ + json_representation: str = thrift_simple_json(group_by) + java_group_by: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseGroupBy(json_representation) + return self.jvm.ai.chronon.spark.S3Utils.readAndUpdateS3PrefixesForGroupBy( + java_group_by, end_date, self.java_spark_session + ) + + def join_to_java(self, join: Join, end_date: str) -> JavaObject: + """ + Convert Join object to Java representation with updated S3 prefixes. + + Args: + join: The Join object to convert + end_date: End date for execution + + Returns: + Java representation of the Join + """ + json_representation: str = thrift_simple_json(join) + java_join: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseJoin(json_representation) + return self.jvm.ai.chronon.spark.S3Utils.readAndUpdateS3PrefixesForJoin( + java_join, end_date, self.java_spark_session + ) + + +class GroupByExecutable(PySparkExecutable[GroupBy], ABC): + """Interface for executing GroupBy objects""" + + def _update_source_dates_for_group_by(self, group_by: GroupBy, start_date: str, end_date: str) -> GroupBy: + """ + Update start and end dates of sources in GroupBy. + + Args: + group_by: The GroupBy object to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated GroupBy object + """ + if not group_by.sources: + return group_by + + for i, source in enumerate(group_by.sources): + group_by.sources[i] = self._update_source_dates(source, start_date, end_date) + return group_by + + def run(self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_execution_of_underlying_join: bool = False) -> DataFrame: + """ + Execute the GroupBy object. + + Args: + start_date: Start date for the execution (format: YYYYMMDD) + end_date: End date for the execution (format: YYYYMMDD) + step_days: Number of days to process in each step + skip_execution_of_underlying_join: Whether to skip execution of underlying joins + + Returns: + DataFrame with the execution results + """ + + start_date: str = start_date or self.default_start_date + end_date: str = end_date or self.default_end_date + + self.platform.log_operation(f"Executing GroupBy {self.obj.metaData.name} from {start_date} to {end_date} with step_days {step_days}") + self.platform.log_operation(f"Skip Execution of Underlying Join Sources: {skip_execution_of_underlying_join}") + + if not skip_execution_of_underlying_join: + self._execute_underlying_join_sources(group_bys=[self.obj], start_date=start_date, end_date=end_date, step_days=step_days) + + # Prepare GroupBy for execution + group_by_to_execute: GroupBy = copy.deepcopy(self.obj) + group_by_to_execute.backfillStartDate = start_date + + # Update sources with correct dates + group_by_to_execute: GroupBy = self._update_source_dates_for_group_by(group_by_to_execute, start_date, end_date) + + # Get output table name + group_by_output_table: str = output_table_name(group_by_to_execute, full_name=True) + + # GroupBy backfills don't store the semantic hash as a property in the table the same way joins do. + # Therefore we drop the backfill table to avoid data quality issues. 
+ self.platform.drop_table_if_exists(table_name=group_by_output_table) + + # Find starting point for log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Run GroupBy: {self.obj.metaData.name}") + + try: + # Convert to Java GroupBy + java_group_by = self.group_by_to_java(group_by_to_execute, end_date) + # Execute GroupBy + result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runGroupBy( + java_group_by, + end_date, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(str(step_days)), + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + + result_df = DataFrame(result_df_scala, self.spark) + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"GroupBy {self.obj.metaData.name} executed successfully and was written to {group_by_output_table}") + return result_df + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Execution failed for GroupBy {self.obj.metaData.name}: {str(e)}") + raise e + + def analyze(self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False) -> None: + """ + Analyze the GroupBy object. + + Args: + start_date: Start date for analysis (format: YYYYMMDD) + end_date: End date for analysis (format: YYYYMMDD) + enable_hitter_analysis: Whether to enable hitter analysis + """ + start_date = start_date or self.default_start_date + end_date = end_date or self.default_end_date + + self.platform.log_operation(f"Analyzing GroupBy {self.obj.metaData.name} from {start_date} to {end_date}") + self.platform.log_operation(f"Enable Hitter Analysis: {enable_hitter_analysis}") + + # Prepare GroupBy for analysis + group_by_to_analyze: GroupBy = copy.deepcopy(self.obj) + + # Update sources with correct dates + group_by_to_analyze: GroupBy = self._update_source_dates_for_group_by(group_by_to_analyze, start_date, end_date) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Analyze GroupBy: {self.obj.metaData.name}") + + try: + # Convert to Java GroupBy + java_group_by = self.group_by_to_java(group_by_to_analyze, end_date) + # Analyze GroupBy + self.jvm.ai.chronon.spark.PySparkUtils.analyzeGroupBy( + java_group_by, + start_date, + end_date, + enable_hitter_analysis, + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"GroupBy {self.obj.metaData.name} analyzed successfully") + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Analysis failed for GroupBy {self.obj.metaData.name}: {str(e)}") + raise e + + + def validate(self, + start_date: str | None = None, + end_date: str | None = None) -> None: + """ + Validate the GroupBy object. 
+ + Args: + start_date: Start date for validation (format: YYYYMMDD) + end_date: End date for validation (format: YYYYMMDD) + """ + platform = self.get_platform() + start_date = start_date or self.default_start_date + end_date = end_date or self.default_end_date + + self.platform.log_operation(f"Validating GroupBy {self.obj.metaData.name} from {start_date} to {end_date}") + + # Prepare GroupBy for validation + group_by_to_validate = copy.deepcopy(self.obj) + + # Update sources with correct dates + group_by_to_validate: GroupBy = self._update_source_dates_for_group_by(group_by_to_validate, start_date, end_date) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Validate GroupBy: {self.obj.metaData.name}") + + try: + # Convert to Java GroupBy + java_group_by: JavaObject = self.group_by_to_java(group_by_to_validate, end_date) + # Validate GroupBy + errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateGroupBy( + java_group_by, + start_date, + end_date, + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + + self.platform.end_log_capture(log_token) + self.platform.handle_validation_errors(errors_list, f"GroupBy {self.obj.metaData.name}") + self.platform.log_operation(f"Validation for GroupBy {self.obj.metaData.name} has completed") + except Exception as e: + self.platform.log_operation(f"Validation failed for GroupBy {self.obj.metaData.name}: {str(e)}") + self.platform.end_log_capture(log_token) + raise e + + + + + + +class JoinExecutable(PySparkExecutable[Join], ABC): + """Interface for executing Join objects""" + + + def _update_source_dates_for_join_parts(self, join_parts: list[JoinPart], start_date: str, end_date: str) -> list[JoinPart]: + """ + Update start and end dates of sources in JoinParts. + + Args: + join_parts: List of JoinPart objects + start_date: The new start date + end_date: The new end date + + Returns: + The updated list of JoinPart objects + """ + if not join_parts: + return [] + + for jp in join_parts: + for i, source in enumerate(jp.groupBy.sources): + jp.groupBy.sources[i] = self._update_source_dates(source, start_date, end_date) + return join_parts + + def run(self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_first_hole: bool = False, + sample_num_of_rows: int | None = None, + skip_execution_of_underlying_join: bool = False) -> DataFrame: + """ + Execute the Join object with Join-specific parameters. 
+ + Args: + start_date: Start date for the execution (format: YYYYMMDD) + end_date: End date for the execution (format: YYYYMMDD) + step_days: Number of days to process in each step + skip_first_hole: Whether to skip the first hole in the join + sample_num_of_rows: Number of rows to sample (None for all) + skip_execution_of_underlying_join: Whether to skip execution of underlying joins + + Returns: + DataFrame with the execution results + """ + start_date = start_date or self.default_start_date + end_date = end_date or self.default_end_date + + self.platform.log_operation(f"Executing Join {self.obj.metaData.name} from {start_date} to {end_date} with step_days {step_days}") + self.platform.log_operation(f"Skip First Hole: {skip_first_hole}") + self.platform.log_operation(f"Sample Number of Rows: {sample_num_of_rows}") + self.platform.log_operation(f"Skip Execution of Underlying Join: {skip_execution_of_underlying_join}") + + # Prepare Join for execution + join_to_execute = copy.deepcopy(self.obj) + join_to_execute.left = self._update_source_dates(join_to_execute.left, start_date, end_date) + + if not skip_execution_of_underlying_join and self.obj.joinParts: + self._execute_underlying_join_sources([jp.groupBy for jp in join_to_execute.joinParts], start_date, end_date, step_days) + + + # Update join parts sources + join_to_execute.joinParts = self._update_source_dates_for_join_parts( + join_to_execute.joinParts, start_date, end_date + ) + + # Get output table name + join_output_table = output_table_name(join_to_execute, full_name=True) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Run Join: {self.obj.metaData.name}") + + try: + # Convert to Java Join + java_join: JavaObject = self.join_to_java(join_to_execute, end_date) + # Execute Join + result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runJoin( + java_join, + end_date, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(str(step_days)), + skip_first_hole, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(None if not sample_num_of_rows else str(sample_num_of_rows)), + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + + result_df = DataFrame(result_df_scala, self.spark) + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Join {self.obj.metaData.name} executed successfully and was written to {join_output_table}") + return result_df + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Execution failed for Join {self.obj.metaData.name}: {str(e)}") + raise e + + def analyze(self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False) -> None: + """ + Analyze the Join object. 
+ + Args: + start_date: Start date for analysis (format: YYYYMMDD) + end_date: End date for analysis (format: YYYYMMDD) + enable_hitter_analysis: Whether to enable hitter analysis + """ + start_date: str = start_date or self.default_start_date + end_date: str = end_date or self.default_end_date + + self.platform.log_operation(f"Analyzing Join {self.obj.metaData.name} from {start_date} to {end_date}") + self.platform.log_operation(f"Enable Hitter Analysis: {enable_hitter_analysis}") + + # Prepare Join for analysis + join_to_analyze: Join = copy.deepcopy(self.obj) + join_to_analyze.left = self._update_source_dates(join_to_analyze.left, start_date, end_date) + + # Update join parts sources + join_to_analyze.joinParts = self._update_source_dates_for_join_parts( + join_to_analyze.joinParts, start_date, end_date + ) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Analyze Join: {self.obj.metaData.name}") + + try: + # Convert to Java Join + java_join: JavaObject = self.join_to_java(join_to_analyze, end_date) + # Analyze Join + self.jvm.ai.chronon.spark.PySparkUtils.analyzeJoin( + java_join, + start_date, + end_date, + enable_hitter_analysis, + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Join {self.obj.metaData.name} analyzed successfully") + + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Analysis failed for Join {self.obj.metaData.name}: {str(e)}") + raise e + + + + + def validate(self, + start_date: str | None = None, + end_date: str | None = None) -> None: + """ + Validate the Join object. + + Args: + start_date: Start date for validation (format: YYYYMMDD) + end_date: End date for validation (format: YYYYMMDD) + """ + start_date: str = start_date or self.default_start_date + end_date: str = end_date or self.default_end_date + + self.platform.log_operation(f"Validating Join {self.obj.metaData.name} from {start_date} to {end_date}") + + # Prepare Join for validation + join_to_validate: Join = copy.deepcopy(self.obj) + join_to_validate.left = self._update_source_dates(join_to_validate.left, start_date, end_date) + + # Update join parts sources + join_to_validate.joinParts = self._update_source_dates_for_join_parts( + join_to_validate.joinParts, start_date, end_date + ) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture(f"Validate Join: {self.obj.metaData.name}") + + try: + # Convert to Java Join + java_join: JavaObject = self.join_to_java(join_to_validate, end_date) + # Validate Join + errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateJoin( + java_join, + start_date, + end_date, + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) + + self.platform.end_log_capture(log_token) + # Handle validation errors + self.platform.handle_validation_errors(errors_list, f"Join {self.obj.metaData.name}") + self.platform.log_operation(f"Validation for Join {self.obj.metaData.name} has completed") + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation(f"Validation failed for Join {self.obj.metaData.name}: {str(e)}") + raise e + + + + +class PlatformInterface(ABC): + """ + Interface for platform-specific operations. + + This class defines operations that vary by platform (Databricks, Jupyter, etc.) + and should be implemented by platform-specific classes. 
+ """ + + def __init__(self, spark: SparkSession) -> None: + """ + Initialize with a SparkSession. + + Args: + spark: The SparkSession to use + """ + self.spark = spark + self.jvm = spark._jvm + self.java_spark_session = spark._jsparkSession + self.register_udfs() + + @abstractmethod + def get_constants_provider(self) -> JavaObject: + """ + Get the platform-specific constants provider. + + Returns: + A JavaObject representing the constants provider + """ + pass + + @abstractmethod + def get_table_utils(self) -> JavaObject: + """ + Get the platform-specific table utilities. + + Returns: + A JavaObject representing the table utilities + """ + pass + + @abstractmethod + def get_executable_join_cls(self) -> type[JoinExecutable]: + """ + Get the class for executing joins. + + Returns: + The class for executing joins + """ + pass + + @abstractmethod + def start_log_capture(self, job_name: str) -> Any: + """ + Start capturing logs for a job. + + Args: + job_name: The name of the job for log headers + + Returns: + A token representing the capture state (platform-specific) + """ + pass + + @abstractmethod + def end_log_capture(self, capture_token: Any) -> None: + """ + End log capturing and print the logs. + + Args: + capture_token: The token returned from start_log_capture + """ + pass + + def register_udfs(self) -> None: + """ + Register UDFs for the self.platform. + + This method is intentionally left empty but not abstract, as some platforms may not need to register UDFs. + + Subclasses can override this method to provide platform-specific UDF registration. + + Pro tip: Both the JVM Spark Session and Python Spark Session use the same spark-sql engine. You can register Python UDFS and use them in the JVM, as well as vice-versa. + At Stripe we currently only use Scala UDFs, so we include a JAR of our UDFs in the cluster and register them via: + + self.jvm.com.stripe.shepherd.udfs.ShepherdUdfRegistry.registerAllUdfs(self.java_spark_session) + """ + pass + + def log_operation(self, message: str) -> None: + """ + Log an operation. + + Args: + message: The message to log + """ + current_utc_time = datetime.utcnow() + time_str = current_utc_time.strftime('[%Y-%m-%d %H:%M:%S UTC]') + print(f'{time_str} {message}') + + def drop_table_if_exists(self, table_name: str) -> None: + """ + Drop a table if it exists. + + Args: + table_name: The name of the table to drop + """ + _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") + + + def handle_validation_errors(self, errors: JavaObject, object_name: str) -> None: + """ + Handle validation errors. + + Args: + errors: Platform-specific validation errors + object_name: Name of the object being validated + """ + if errors.length() > 0: + self.log_operation(message=f"Validation failed for {object_name} with the following errors:") + self.log_operation(message=str(errors)) + else: + self.log_operation(message=f"Validation passed for {object_name}.") + + def set_metadata(self, obj: GroupBy | Join | StagingQuery, mod_prefix: str, + name_prefix: str | None = None, output_namespace: str | None = None) -> T: + """ + Set the metadata for the object. 
+ + Args: + obj: The object to set metadata on + mod_prefix: The root directory of where your features exist + name_prefix: Optional prefix to add to the name + output_namespace: Optional namespace to override the set output namespace + + Returns: + The updated object + """ + obj_type: type[GroupBy] | type[Join] = type(obj) + + # Handle object naming + if not obj.metaData.name: + try: + set_name(obj=obj, cls=obj_type, mod_prefix=mod_prefix) + except AttributeError: + raise AttributeError("Please provide a name when defining group_bys/joins/staging_queries adhoc.") + # We do this to avoid adding the name prefix multiple times + elif obj.metaData.name and name_prefix: + obj.metaData.name = obj.metaData.name.replace(f"{name_prefix}_", "") + + # Handle nested objects for GroupBy + if obj_type == GroupBy: + for s in obj.sources: + if s.joinSource and not s.joinSource.join.metaData.name: + set_name(obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix) + # Handle nested objects for Join + elif obj_type == Join: + for jp in obj.joinParts: + for s in jp.groupBy.sources: + if s.joinSource and not s.joinSource.join.metaData.name: + set_name(obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix) + + # Set output namespace + if output_namespace: + obj.metaData.outputNamespace = output_namespace + + if obj_type == Join: + for jp in obj.joinParts: + jp.groupBy.metaData.outputNamespace = output_namespace + + # Add user prefix to name + if name_prefix: + obj.metaData.name = f"{name_prefix}_{obj.metaData.name}" + + return obj + diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index 870231a458..0966d9a912 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -575,3 +575,33 @@ def get_config_path(join_name: str) -> str: assert "." in join_name, f"Invalid join name: {join_name}" team_name, config_name = join_name.split(".", 1) return f"production/joins/{team_name}/{config_name}" + + +def get_max_window_for_gb_in_days(group_by: api.GroupBy) -> int: + result: int = 1 + if group_by.aggregations: + for agg in group_by.aggregations: + for window in agg.windows: + if window.timeUnit == api.TimeUnit.MINUTES: + result = int( + max( + result, + ceil(window.length / 60 * 24), + ) + ) + elif window.timeUnit == api.TimeUnit.HOURS: + result = int( + max( + result, + ceil(window.length / 24), + ) + ) + elif window.timeUnit == api.TimeUnit.DAYS: + result = int( + max(result, window.length) + ) + else: + raise ValueError( + f"Unsupported time unit {window.timeUnit}. Please add logic above to handle the newly introduced time unit." 
+ ) + return result \ No newline at end of file diff --git a/api/py/requirements/base.in b/api/py/requirements/base.in index 74b7ad51e5..167cc3818b 100644 --- a/api/py/requirements/base.in +++ b/api/py/requirements/base.in @@ -1,3 +1,4 @@ click thrift<0.14 sqlglot +pyspark==3.3.1 \ No newline at end of file diff --git a/api/py/requirements/base.txt b/api/py/requirements/base.txt index e5570ef610..7c2a8d8e20 100644 --- a/api/py/requirements/base.txt +++ b/api/py/requirements/base.txt @@ -1,4 +1,4 @@ -# SHA1:4478faec22832482c18fa37db7c49640e4dc015b +# SHA1:b59a9d5665a7082dea5b0bc9f4d1bc2ec737d9e8 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -6,16 +6,14 @@ # pip-compile-multi # click==8.1.8 - # via -r requirements/base.in -importlib-metadata==6.7.0 - # via click + # via -r base.in +py4j==0.10.9.5 + # via pyspark +pyspark==3.3.1 + # via -r base.in six==1.17.0 # via thrift -sqlglot==26.8.0 - # via -r requirements/base.in +sqlglot==26.16.1 + # via -r base.in thrift==0.13.0 - # via -r requirements/base.in -typing-extensions==4.7.1 - # via importlib-metadata -zipp==3.15.0 - # via importlib-metadata + # via -r base.in diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala new file mode 100644 index 0000000000..387e2c3300 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -0,0 +1,209 @@ +package ai.chronon.spark +import ai.chronon.aggregator.windowing.{FiveMinuteResolution, Resolution} +import ai.chronon.api +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.api.{ConstantNameProvider, Constants, ThriftJsonCodec} +import org.apache.spark.sql.DataFrame +import org.slf4j.LoggerFactory + +object PySparkUtils { + @transient lazy val logger = LoggerFactory.getLogger(getClass) + + /** + * Pyspark has a tough time creating a FiveMinuteResolution via jvm.ai.chronon.aggregator.windowing.FiveMinuteResolution so we provide this helper method + * @return FiveMinuteResolution + */ + + def getFiveMinuteResolution: Resolution = FiveMinuteResolution + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * @param timeRange a time range + * @return Empty time range optional + */ + def getTimeRangeOptional(timeRange: TimeRange): Option[TimeRange] = if (timeRange == null) Option.empty[TimeRange] else Option(timeRange) + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * @param str a string + * @return String optional + */ + def getStringOptional(str : String) : Option[String] = if (str == null) Option.empty[String] else Option(str) + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * Furthermore, ints can't be null in Scala so we need to pass the value in as a str + * @param strInt a string + * @return Int optional + */ + def getIntOptional(strInt : String) : Option[Int] = if (strInt == null) Option.empty[Int] else Option(strInt.toInt) + + /** + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * @param groupByJson a JSON string representing a group by + * @return Chronon Scala API GroupBy object + */ + def parseGroupBy(groupByJson: String): api.GroupBy = { + ThriftJsonCodec.fromJsonStr[api.GroupBy](groupByJson, check = true, classOf[api.GroupBy]) + } + + /** + * Type parameters are difficult to support in Pyspark, so we provide 
these helper methods for ThriftJsonCodec.fromJsonStr + * @param joinJson a JSON string representing a join + * @return Chronon Scala API Join object + */ + def parseJoin(joinJson: String): api.Join = { + ThriftJsonCodec.fromJsonStr[api.Join](joinJson, check = true, classOf[api.Join]) + } + + /** + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * @param sourceJson a JSON string representing a source. + * @return Chronon Scala API Source object + */ + def parseSource(sourceJson: String): api.Source = { + ThriftJsonCodec.fromJsonStr[api.Source](sourceJson, check = true, classOf[api.Source]) + } + + /** + * Helper function to get Temporal or Snapshot Accuracy + * + * @param getTemporal boolean value that will decide if we return temporal or snapshot accuracy . + * @return api.Accuracy + */ + def getAccuracy(getTemporal: Boolean): api.Accuracy = { + if (getTemporal) api.Accuracy.TEMPORAL else api.Accuracy.SNAPSHOT + } + + /** + * Helper function to allow a user to execute a Group By. + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param endDate str this represents the last date we will perform the aggregation for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @return DataFrame + */ + def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : DataFrame = { + logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + GroupBy.computeBackfill( + groupByConf, + endDate, + tableUtils, + stepDays + ) + logger.info(s"Finished executing GroupBy: ${groupByConf.metaData.name}") + tableUtils.sql(s"SELECT * FROM ${groupByConf.metaData.outputTable}") + } + + /** + * Helper function to allow a user to execute a Join. + * + * @param joinConf api.Join Chronon scala Join API object + * @param endDate str this represents the last date we will perform the Join for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
+ * @return DataFrame + */ + def runJoin(joinConf: api.Join, + endDate: String, + stepDays: Option[Int], + skipFirstHole: Boolean, + sampleNumOfRows: Option[Int], + tableUtils: BaseTableUtils, + constantsProvider: ConstantNameProvider + ) : DataFrame = { + logger.info(s"Executing Join ${joinConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + val join = new Join( + joinConf, + endDate, + tableUtils, + skipFirstHole = skipFirstHole + ) + val resultDf = join.computeJoin(stepDays) + logger.info(s"Finished executing Join ${joinConf.metaData.name}") + resultDf + } + + /** + * Helper function to validate a GroupBy + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @return DataFrame + */ + def validateGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : List[String] = { + logger.info(s"Validating GroupBy ${groupByConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + val validator = new Validator(tableUtils, groupByConf, startDate, endDate) + val result = validator.validateGroupBy(groupByConf) + logger.info(s"Finished validating GroupBy ${groupByConf.metaData.name}") + result + } + + + /** + * Helper function to validate a Join + * + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end date for the join + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @return DataFrame + */ + def validateJoin(joinConf: api.Join, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : List[String] = { + logger.info(s"Validating Join: ${joinConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + val validator = new Validator(tableUtils, joinConf, startDate, endDate) + val result = validator.validateJoin(joinConf) + logger.info(s"Finished validating Join: ${joinConf.metaData.name}") + result + } + + /** + * Helper function to analyze a GroupBy + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
+ */ + def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : Unit = { + logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) + analyzer.analyzeGroupBy(groupByConf, enableHitter = enableHitterAnalysis) + logger.info(s"Finished analyzing GroupBy: ${groupByConf.metaData.name}") + } + + + /** + * Helper function to analyze a Join + * + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end date for the join + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @return DataFrame + */ + def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : Unit = { + logger.info(s"Analyzing Join: ${joinConf.metaData.name}") + Constants.initConstantNameProvider(constantsProvider) + val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) + analyzer.analyzeJoin(joinConf, enableHitter = enableHitterAnalysis) + logger.info(s"Finished analyzing Join: ${joinConf.metaData.name}") + } + +} diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala new file mode 100644 index 0000000000..6d3566e806 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala @@ -0,0 +1,12 @@ +package ai.chronon.spark.databricks + +import ai.chronon.api.Extensions.{WindowOps, WindowUtils} +import ai.chronon.api.{ConstantNameProvider, PartitionSpec} + +class DatabricksConstantsNameProvider extends ConstantNameProvider with Serializable { + override def TimeColumn: String = "_internal_time_column" + override def DatePartitionColumn: String = "day" + + override def HourPartitionColumn: String = "hr" + override def Partition: PartitionSpec = PartitionSpec(format = "yyyyMMdd", spanMillis = WindowUtils.Day.millis) +} \ No newline at end of file diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala new file mode 100644 index 0000000000..f9d5059731 --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -0,0 +1,117 @@ +package ai.chronon.spark.databricks + +import ai.chronon.api.Constants +import ai.chronon.spark.BaseTableUtils +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.{Column, SparkSession} + +case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils { + + private def isTempView(tableName: String): Boolean = + sparkSession.sessionState.catalog.isTempView(tableName.split('.')) + + override def createTableSql(tableName: String, schema: StructType, partitionColumns: Seq[String], tableProperties: Map[String, String], fileFormat: String): String = { + + 
// Side effect: Creates the table in iceberg and then confirms the table creation was successfully by issuing a SHOW CREATE TABLE. + val partitionSparkCols: Seq[Column] = partitionColumns.map(col => org.apache.spark.sql.functions.col(col)) + + val emptyDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[org.apache.spark.sql.Row], schema) + + val createTableWriter = { + if(Option(tableProperties).exists(_.nonEmpty)) { + tableProperties + .foldLeft(emptyDf.writeTo(tableName).using("iceberg"))((createTableWriter, tblProperty) => createTableWriter.tableProperty(tblProperty._1, tblProperty._2)) + } + else { + emptyDf.writeTo(tableName).using("iceberg") + } + } + + + partitionSparkCols match { + case Seq() => + logger.info(s"Creating table $tableName without partitioning") + createTableWriter.create() + case Seq(head) => + logger.info(s"Creating table $tableName partitioned by $head") + createTableWriter.partitionedBy(head).create() + case head +: tail => + logger.info(s"Creating table $tableName partitioned by $head and $tail") + createTableWriter.partitionedBy(head, tail: _*).create() + } + + s"SHOW CREATE TABLE $tableName" // Confirm table creation in subsequent call + } + + override def getIcebergPartitions(tableName: String): Seq[String] = { + val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions") + val index = partitionsDf.schema.fieldIndex("partition") + if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) { + // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718 + // so we collect and then filter. + partitionsDf + .select(s"partition.${partitionColumn}", s"partition.${Constants.HourPartitionColumn}") + .collect() + .filter(_.get(1) == null) + .map(_.getString(0)) + .toSeq + } else if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("locality_zone")) { + partitionsDf + // TODO(FCOMP-2242) We should factor out a provider for getting Iceberg partitions + // so we can inject a Stripe-specific one that takes into account locality_zone + .select(s"partition.${partitionColumn}") + .where("partition.locality_zone == 'DEFAULT'") + .collect() + .map(_.getString(0)) + .toSeq + } + else { + partitionsDf + .select(s"partition.${partitionColumn}") + .collect() + .map(_.getString(0)) + .toSeq + } + } + + override def partitions( + tableName: String, + subPartitionsFilter: Map[String, String] = Map.empty, + partitionColumnOverride: String = Constants.PartitionColumn + ): Seq[String] = + // This is to support s3 prefix inputs. + if (isTempView(tableName)) { + if (subPartitionsFilter.nonEmpty) { + throw new NotImplementedError( + s"partitions cannot be called with tableName ${tableName} subPartitionsFilter ${subPartitionsFilter} because subPartitionsFilter is not supported on tempViews yet." + ) + } + // If the table is a temp view, fallback to querying for the distinct partition columns because SHOW PARTITIONS + // doesn't work on temp views. This can be inefficient for large tables. + logger.info( + s"Selecting partitions for temp view table tableName=$tableName, " + + s"partitionColumnOverride=$partitionColumnOverride." + ) + val outputPartitionsRowsDf = sql( + s"select distinct ${partitionColumnOverride} from $tableName" + ) + val outputPartitionsRows = outputPartitionsRowsDf.collect() + // Users have ran into issues where the partition column is not a string, so add logging to facilitate debug. 
+ logger.info( + s"Found ${outputPartitionsRows.length} partitions for temp view table tableName=$tableName. The partition schema is ${outputPartitionsRowsDf.schema}." + ) + + val outputPartitionsRowsDistinct = outputPartitionsRows.map(_.getString(0)).distinct + logger.info( + s"Converted ${outputPartitionsRowsDistinct.length} distinct partitions for temp view table to String. " + + s"tableName=$tableName." + ) + + outputPartitionsRowsDistinct + } else { + super.partitions(tableName, subPartitionsFilter, partitionColumnOverride) + } + + + +} \ No newline at end of file From be70595711d666ce3b9fbfe78a989be5bcb85dc3 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 11:50:32 -0700 Subject: [PATCH 02/31] Remove stripe specific comment --- .../ai/chronon/spark/databricks/DatabricksTableUtils.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index f9d5059731..5efc9e0708 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -57,8 +57,6 @@ case class DatabricksTableUtils (override val sparkSession: SparkSession) extend .toSeq } else if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("locality_zone")) { partitionsDf - // TODO(FCOMP-2242) We should factor out a provider for getting Iceberg partitions - // so we can inject a Stripe-specific one that takes into account locality_zone .select(s"partition.${partitionColumn}") .where("partition.locality_zone == 'DEFAULT'") .collect() From 8b0bd1a926562dbe4f77388f36eb545e3ec97b96 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:28:41 -0700 Subject: [PATCH 03/31] Add base README file --- api/py/ai/chronon/pyspark/README.md | 328 ++++++++++++++++++++++++++++ 1 file changed, 328 insertions(+) create mode 100644 api/py/ai/chronon/pyspark/README.md diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md new file mode 100644 index 0000000000..665d0bdc44 --- /dev/null +++ b/api/py/ai/chronon/pyspark/README.md @@ -0,0 +1,328 @@ +# Chronon Python Interface for PySpark Environments + +## Table of Contents +1. [Introduction](#introduction) +2. [Architecture Overview](#architecture-overview) +3. [Core Components](#core-components) +4. [Flow of Execution](#flow-of-execution) +5. [Extending the Framework](#extending-the-framework) +6. [Setup and Dependencies](#setup-and-dependencies) + +## Introduction + +The Chronon PySpark Interface provides a clean, object-oriented framework for executing Chronon feature definitions directly within a PySpark environment, like Databricks Notebooks. This interface streamlines the developer experience by removing the need to switch between multiple tools, allowing rapid prototyping and iteration of Chronon feature engineering workflows. 
+ +This library enables users to: +- Run, analyze, and validate GroupBy and Join operations in a type-safe manner +- Execute feature computations within notebook environments like Databricks +- Implement platform-specific behavior while preserving a consistent interface +- Access JVM-based functionality directly from Python code + +## Architecture Overview + +### The Python-JVM Bridge + +At the core of this implementation is the interaction between Python and the Java Virtual Machine (JVM): + +``` +Python Environment | JVM Environment + | + ┌─────────────────────┐ | ┌─────────────────────┐ + │ Python Thrift Obj │ | │ Scala Thrift Obj │ + │ (GroupBy, Join) │─────┐ | ┌────▶│ (GroupBy, Join) │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ │ | │ │ + ▼ │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ thrift_simple_json()│ │ | │ │ PySparkUtils.parse │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ │ | │ │ + ▼ │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ JSON String │─────┼──────┼──────┼────▶│ Java Objects │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ | │ │ + │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ PySparkExecutable │ │ | │ │ PySparkUtils.run │ + │ .run() │─────┼──────┼──────┼────▶│ GroupBy/Join │ + └─────────────────────┘ │ | │ └─────────────────────┘ + ▲ │ | │ │ + │ │ | │ ▼ + ┌─────────────────────┐ │ Py4J Socket │ ┌─────────────────────┐ + │ Python DataFrame │◀────┴──────┼──────┴─────│ JVM DataFrame │ + └─────────────────────┘ | └─────────────────────┘ + | + | + | + | + | + +``` + +- **Py4J**: Enables Python code to dynamically access Java objects, methods, and fields across the JVM boundary +- **PySpark**: Uses Py4J to communicate with the Spark JVM, translating Python calls into Spark's Java/Scala APIs +- **Thrift Objects**: Chronon features defined as Python thrift objects are converted to Java thrift objects for execution + +This design ensures that Python users can access the full power of Chronon's JVM-based computation engine all from a centralized Python environment. + +## Core Components + +The framework is built around several core abstractions: + +### PySparkExecutable + +An abstract base class that provides common functionality for executing Chronon features via PySpark: + +```python +class PySparkExecutable(Generic[T], ABC): + """ + Abstract base class defining common functionality for executing features via PySpark. + """ +``` + +- Handles initialization with object and SparkSession +- Provides utilities for updating dates in sources and queries +- Manages the execution of underlying join sources + +### Specialized Executables + +Two specialized interfaces extend the base executable for different Chronon types: + +- **GroupByExecutable**: Interface for executing GroupBy objects +- **JoinExecutable**: Interface for executing Join objects + +These interfaces define type-specific behaviors for running, analyzing, and validating features. + +### Platform Interface + +A key abstraction that enables platform-specific behavior: + +```python +class PlatformInterface(ABC): + """ + Interface for platform-specific operations. + """ +``` + +This interface defines operations that vary by platform (Databricks, Jupyter, etc.) and must be implemented by platform-specific classes. 
+ +### Platform-Specific Implementations + +Concrete implementations for specific notebook environments: + +- **DatabricksPlatform**: Implements platform-specific operations for Databricks +- **DatabricksGroupBy**: Executes GroupBy objects in Databricks +- **DatabricksJoin**: Executes Join objects in Databricks + +``` +┌─────────────────────────┐ +│ PySparkExecutable │ +│ (Generic[T], ABC) │ +├─────────────────────────┤ +│ - obj: T │ +│ - spark: SparkSession │ +│ - jvm: JVMView │ +├─────────────────────────┤ +│ + get_platform() │ +│ # _update_query_dates() │ +│ # _update_source_dates()│ +│ # print_with_timestamp()│ +└───────────────┬─────────┘ + │ + ┌───────────┴────────────┐ + │ │ +┌───▼───────────────┐ ┌───▼───────────────┐ +│ GroupByExecutable│ │ JoinExecutable │ +├───────────────────┤ ├───────────────────┤ +│ │ │ │ +├───────────────────┤ ├───────────────────┤ +│ + run() │ │ + run() │ +│ + analyze() │ │ + analyze() │ +│ + validate() │ │ + validate() │ +└────────┬──────────┘ └────────┬──────────┘ + │ │ + │ │ +┌────────▼──────────┐ ┌────────▼──────────┐ +│ DatabricksGroupBy │ │ DatabricksJoin │ +├───────────────────┤ ├───────────────────┤ +│ │ │ │ +├───────────────────┤ ├───────────────────┤ +│ + get_platform() │ │ + get_platform() │ +└───────────────────┘ └───────────────────┘ + +┌─────────────────────────────┐ +│ PlatformInterface │ +│ (ABC) │ +├─────────────────────────────┤ +│ - spark: SparkSession │ +├─────────────────────────────┤ +│ + get_constants_provider() │ +│ + get_table_utils() │ +│ + register_udfs() │ +│ + get_executable_join_cls() │ +│ + start_log_capture() │ +│ + end_log_capture() │ +│ + log_operation() │ +│ + drop_table_if_exists() │ +│ + handle_validation_errors()│ +└───────────┬─────────────────┘ + │ + │ +┌───────────▼────────────────┐ +│ DatabricksPlatform │ +├────────────────────────────┤ +│ - dbutils: DBUtils │ +│ - constants_provider │ +│ - table_utils │ +├────────────────────────────┤ +│ + get_constants_provider() │ +│ + get_table_utils() │ +│ + register_udfs() │ +│ + get_executable_join_cls()│ +│ + start_log_capture() │ +│ + end_log_capture() │ +│ + get_databricks_user() │ +└────────────────────────────┘ +``` + +## Flow of Execution + +When a user calls a method like `DatabricksGroupBy(group_by, py_spark_session).run()`, the following sequence occurs: + +1. **Object Preparation**: + - The Python thrift object (GroupBy, Join) is copied and updated with appropriate dates (This interface is meant to be used to run prototypes over smaller date ranges and not full backfills) + - Underlying join sources are executed if needed + +2. **JVM Conversion**: + - The Python thrift object is converted to JSON representation + - The JSON is parsed into a Java thrift object on the JVM side via Py4J + +3. **Execution**: + - The JVM executes the computation using Spark + - Results are captured in a Spark DataFrame on the JVM side + +4. **Result Return**: + - The JVM DataFrame is wrapped in a Python DataFrame object + - The Python DataFrame is returned to the user + +5. **Log Handling**: + - Throughout the process, JVM logs are captured + - Logs are printed in the notebook upon completion or errors + +## Extending the Framework + +### Implementing a New Platform Interface + +To add support for a new notebook environment (e.g., Jupyter), follow these steps: + +1. 
**Create a new platform implementation**: + +```python +class JupyterPlatform(PlatformInterface): + def __init__(self, spark: SparkSession): + super().__init__(spark) + # Initialize Jupyter-specific components + + @override + def get_constants_provider(self) -> JavaObject: + # Return Jupyter-specific constants provider + pass + + @override + def get_table_utils(self) -> JavaObject: + # Return Jupyter-specific table utilities + pass + + @override + def register_udfs(self) -> None: + # Register any necessary UDFs for Jupyter + # Recall that UDFs are registered to the shared spark-sql engine + # So you can register python and or scala udfs and use them on both spark sessions + pass + + @override + def get_executable_join_cls(self) -> type[JoinExecutable]: + # Return the Jupyter-specific join executable class + return JupyterJoin + + @override + def start_log_capture(self, job_name: str) -> Any: + # Start capturing logs in Jupyter + pass + + @override + def end_log_capture(self, capture_token: Any) -> None: + # End log capturing and print the logs in Jupyter + pass +``` + +2. **Create concrete executable implementations**: + +```python +class JupyterGroupBy(GroupByExecutable): + def __init__(self, group_by: GroupBy, spark_session: SparkSession): + super().__init__(group_by, spark_session) + # Set metadata as needed + self.obj: GroupBy = self.platform.set_metadata(obj=self.obj) + + @override + def get_platform(self) -> PlatformInterface: + return JupyterPlatform(self.spark) + +class JupyterJoin(JoinExecutable): + def __init__(self, join: Join, spark_session: SparkSession): + super().__init__(join, spark_session) + # Set metadata as needed + self.obj: Join = self.platform.set_metadata(obj=self.obj) + + @override + def get_platform(self) -> PlatformInterface: + return JupyterPlatform(self.spark) +``` + +### Key Methods to Override + +When implementing a platform interface, pay special attention to these methods: + +- **get_constants_provider()**: Return a platform-specific implementation of constants +- **get_table_utils()**: Return platform-specific table utilities +- **start_log_capture()** and **end_log_capture()**: Implement platform-specific log capturing +- **handle_validation_errors()**: Implement platform-specific error handling + +## Setup and Dependencies + +### Requirements + +1. **Spark Dependencies**: The Chronon Java/Scala JARs must be included in your Spark cluster: + ```python + spark.conf.set("spark.jars", "/path/to/chronon-jars.jar") + ``` + +2. **Python Dependencies**: + - pyspark (tested on both 3.1 and 3.3) + - py4j + - chronon_ai (The Python package for Chronon) + +### Example Setup + +Here's a minimal example of setting up and using the Chronon Python interface in a Databricks notebook: + +```python +# Import the required modules +from pyspark.sql import SparkSession +from ai.chronon.pyspark.databricks import DatabricksGroupBy, DatabricksJoin +from ai.chronon.api.ttypes import GroupBy, Join +from ai.chronon.group_by import Aggregation, Operation, Window, TimeUnit + +# Define your GroupBy or Join object +my_group_by = GroupBy(...) 
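+# Note: GroupBy(...) above is a placeholder, not a runnable definition. Fill in your
+# sources, key columns and aggregations, or import a GroupBy that already exists in
+# your feature repo; the Databricks executable will prefix its name with your notebook
+# username and route output to the configured notebook namespace.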
+ +# Create an executable +executable = DatabricksGroupBy(my_group_by, spark) + +# Run the executable +result_df = executable.run(start_date='20250101', end_date='20250107') +``` + +--- \ No newline at end of file From a24bc2fa6be8dc27573599292c8c63b2436520dd Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:55:53 -0700 Subject: [PATCH 04/31] Remvoe s3 logic --- api/py/ai/chronon/pyspark/executables.py | 30 ++++++++++-------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index f0469ae893..107c04bad7 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -10,7 +10,7 @@ from py4j.java_gateway import JavaObject from ai.chronon.api.ttypes import GroupBy, Join, JoinPart, JoinSource, Source, Query, StagingQuery -from ai.chronon.utils import set_name, get_max_window_for_gb_in_days, output_table_name +from ai.chronon.utils import __set_name as set_name, get_max_window_for_gb_in_days, output_table_name from ai.chronon.repo.serializer import thrift_simple_json from ai.chronon.pyspark.constants import PARTITION_COLUMN_FORMAT @@ -163,9 +163,9 @@ def print_with_timestamp(self, message: str) -> None: print(f'{time_str} {message}') - def group_by_to_java(self, group_by: GroupBy, end_date: str) -> JavaObject: + def group_by_to_java(self, group_by: GroupBy) -> JavaObject: """ - Convert GroupBy object to Java representation with updated S3 prefixes. + Convert GroupBy object to Java representation. Args: group_by: The GroupBy object to convert @@ -176,13 +176,11 @@ def group_by_to_java(self, group_by: GroupBy, end_date: str) -> JavaObject: """ json_representation: str = thrift_simple_json(group_by) java_group_by: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseGroupBy(json_representation) - return self.jvm.ai.chronon.spark.S3Utils.readAndUpdateS3PrefixesForGroupBy( - java_group_by, end_date, self.java_spark_session - ) + return java_group_by - def join_to_java(self, join: Join, end_date: str) -> JavaObject: + def join_to_java(self, join: Join) -> JavaObject: """ - Convert Join object to Java representation with updated S3 prefixes. + Convert Join object to Java representation. 
Args: join: The Join object to convert @@ -193,9 +191,7 @@ def join_to_java(self, join: Join, end_date: str) -> JavaObject: """ json_representation: str = thrift_simple_json(join) java_join: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseJoin(json_representation) - return self.jvm.ai.chronon.spark.S3Utils.readAndUpdateS3PrefixesForJoin( - java_join, end_date, self.java_spark_session - ) + return java_join class GroupByExecutable(PySparkExecutable[GroupBy], ABC): @@ -266,7 +262,7 @@ def run(self, try: # Convert to Java GroupBy - java_group_by = self.group_by_to_java(group_by_to_execute, end_date) + java_group_by = self.group_by_to_java(group_by_to_execute) # Execute GroupBy result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runGroupBy( java_group_by, @@ -314,7 +310,7 @@ def analyze(self, try: # Convert to Java GroupBy - java_group_by = self.group_by_to_java(group_by_to_analyze, end_date) + java_group_by = self.group_by_to_java(group_by_to_analyze) # Analyze GroupBy self.jvm.ai.chronon.spark.PySparkUtils.analyzeGroupBy( java_group_by, @@ -359,7 +355,7 @@ def validate(self, try: # Convert to Java GroupBy - java_group_by: JavaObject = self.group_by_to_java(group_by_to_validate, end_date) + java_group_by: JavaObject = self.group_by_to_java(group_by_to_validate) # Validate GroupBy errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateGroupBy( java_group_by, @@ -456,7 +452,7 @@ def run(self, try: # Convert to Java Join - java_join: JavaObject = self.join_to_java(join_to_execute, end_date) + java_join: JavaObject = self.join_to_java(join_to_execute) # Execute Join result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runJoin( java_join, @@ -509,7 +505,7 @@ def analyze(self, try: # Convert to Java Join - java_join: JavaObject = self.join_to_java(join_to_analyze, end_date) + java_join: JavaObject = self.join_to_java(join_to_analyze) # Analyze Join self.jvm.ai.chronon.spark.PySparkUtils.analyzeJoin( java_join, @@ -559,7 +555,7 @@ def validate(self, try: # Convert to Java Join - java_join: JavaObject = self.join_to_java(join_to_validate, end_date) + java_join: JavaObject = self.join_to_java(join_to_validate) # Validate Join errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateJoin( java_join, From 959d3f090f276721542858bb7848af653588d6f0 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:58:54 -0700 Subject: [PATCH 05/31] Update readme --- api/py/ai/chronon/pyspark/README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md index 665d0bdc44..147e3e1553 100644 --- a/api/py/ai/chronon/pyspark/README.md +++ b/api/py/ai/chronon/pyspark/README.md @@ -263,7 +263,11 @@ class JupyterPlatform(PlatformInterface): class JupyterGroupBy(GroupByExecutable): def __init__(self, group_by: GroupBy, spark_session: SparkSession): super().__init__(group_by, spark_session) - # Set metadata as needed + # You can pass Jupyter specific parameters into to set metadata + # that allow you to customize things like: + # - What namespace is written to + # - Table name prefixing (in the Databricks implementation we prefix the table name with the notebook username) + # - Root dir for where your existing feature defs are if you want to import features that were defined in an IDE self.obj: GroupBy = self.platform.set_metadata(obj=self.obj) @override @@ -273,7 +277,11 @@ class 
JupyterGroupBy(GroupByExecutable): class JupyterJoin(JoinExecutable): def __init__(self, join: Join, spark_session: SparkSession): super().__init__(join, spark_session) - # Set metadata as needed + # You can pass Jupyter specific parameters into to set metadata + # that allow you to customize things like: + # - What namespace is written to + # - Table name prefixing (in the Databricks implementation we prefix the table name with the notebook username) + # - Root dir for where your existing feature defs are if you want to import features that were defined in an IDE self.obj: Join = self.platform.set_metadata(obj=self.obj) @override From 295cc2b73d7a3a2cf51e785ffae21f3389715b56 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:00:07 -0700 Subject: [PATCH 06/31] Update readme --- api/py/ai/chronon/pyspark/README.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md index 147e3e1553..b2d4a94623 100644 --- a/api/py/ai/chronon/pyspark/README.md +++ b/api/py/ai/chronon/pyspark/README.md @@ -302,19 +302,14 @@ When implementing a platform interface, pay special attention to these methods: ### Requirements -1. **Spark Dependencies**: The Chronon Java/Scala JARs must be included in your Spark cluster: - ```python - spark.conf.set("spark.jars", "/path/to/chronon-jars.jar") - ``` +1. **Spark Dependencies**: The Chronon Java/Scala JARs must be included in your Spark cluster 2. **Python Dependencies**: - pyspark (tested on both 3.1 and 3.3) - - py4j - - chronon_ai (The Python package for Chronon) ### Example Setup -Here's a minimal example of setting up and using the Chronon Python interface in a Databricks notebook: +Here's a minimal example of setting up and using the Chronon Python interface in a Databricks notebook. It assumes that you have already included the necessary jars in your cluster dependencies. 
```python # Import the required modules From 848ba9fad12dd02021e8fb5aa6c253f63d9df11d Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:04:36 -0700 Subject: [PATCH 07/31] Fix imports in utils.py --- api/py/ai/chronon/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index 0966d9a912..99741862ca 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -20,6 +20,7 @@ import shutil import subprocess import tempfile +from math import ceil from collections.abc import Iterable from dataclasses import dataclass, fields from enum import Enum From e45de5fa459175e60fdc555d3004f86725a37cc2 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:05:10 -0700 Subject: [PATCH 08/31] Fix comments in constants --- api/py/ai/chronon/pyspark/constants.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py index 5df155bfec..5b2cc3955d 100644 --- a/api/py/ai/chronon/pyspark/constants.py +++ b/api/py/ai/chronon/pyspark/constants.py @@ -1,18 +1,18 @@ from __future__ import annotations -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Company Specific Constants -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- PARTITION_COLUMN_FORMAT: str = '%Y%m%d' -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Platform Specific Constants -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Databricks Constants -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- DATABRICKS_OUTPUT_NAMESPACE: str = 'chronon_poc_usertables' DATABRICKS_JVM_LOG_FILE: str = "/databricks/chronon_logfile.log" DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" \ No newline at end of file From 7357519cf0d13797b1d51027dcc1c505b8cf3696 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:05:43 -0700 Subject: [PATCH 09/31] add new line to end of constants.py --- api/py/ai/chronon/pyspark/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py index 5b2cc3955d..9a25d54593 100644 --- a/api/py/ai/chronon/pyspark/constants.py +++ b/api/py/ai/chronon/pyspark/constants.py @@ -15,4 +15,4 @@ # -------------------------------------------------------------------------- DATABRICKS_OUTPUT_NAMESPACE: str = 'chronon_poc_usertables' DATABRICKS_JVM_LOG_FILE: str = "/databricks/chronon_logfile.log" -DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" \ No newline at end of file +DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" From 380b928a586f791894cfa5816b896bc6eafe6dd8 Mon Sep 17 
00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:07:36 -0700 Subject: [PATCH 10/31] Change output namespace --- api/py/ai/chronon/pyspark/constants.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py index 9a25d54593..b9f27d41c6 100644 --- a/api/py/ai/chronon/pyspark/constants.py +++ b/api/py/ai/chronon/pyspark/constants.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Optional # -------------------------------------------------------------------------- # Company Specific Constants @@ -13,6 +14,6 @@ # -------------------------------------------------------------------------- # Databricks Constants # -------------------------------------------------------------------------- -DATABRICKS_OUTPUT_NAMESPACE: str = 'chronon_poc_usertables' +DATABRICKS_OUTPUT_NAMESPACE: Optional[str] = None DATABRICKS_JVM_LOG_FILE: str = "/databricks/chronon_logfile.log" DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" From e45f9dcb9709181b7eedb0ababfab52fe340dab9 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:09:34 -0700 Subject: [PATCH 11/31] Add setup step for log file --- api/py/ai/chronon/pyspark/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md index b2d4a94623..804d3772e8 100644 --- a/api/py/ai/chronon/pyspark/README.md +++ b/api/py/ai/chronon/pyspark/README.md @@ -307,6 +307,8 @@ When implementing a platform interface, pay special attention to these methods: 2. **Python Dependencies**: - pyspark (tested on both 3.1 and 3.3) +3. **Log File**: You will need to make sure that your Chronon JVM logs are writting to single file. This is generally platform specific. + ### Example Setup Here's a minimal example of setting up and using the Chronon Python interface in a Databricks notebook. It assumes that you have already included the necessary jars in your cluster dependencies. 
From 9ef39558b6b95503ea38e1a4ca636faebd847c7e Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:27:13 -0700 Subject: [PATCH 12/31] Reformat databricks.py to be flake8 compliant --- api/py/ai/chronon/pyspark/databricks.py | 38 ++++++++++++++----------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/api/py/ai/chronon/pyspark/databricks.py b/api/py/ai/chronon/pyspark/databricks.py index c5cad0ef2c..4d06830365 100644 --- a/api/py/ai/chronon/pyspark/databricks.py +++ b/api/py/ai/chronon/pyspark/databricks.py @@ -1,33 +1,31 @@ from __future__ import annotations import os -from typing_extensions import override from typing import cast -from pyspark.sql import SparkSession -from pyspark.dbutils import DBUtils + from py4j.java_gateway import JavaObject +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession +from typing_extensions import override +from ai.chronon.api.ttypes import GroupBy, Join +from ai.chronon.pyspark.constants import ( + DATABRICKS_JVM_LOG_FILE, + DATABRICKS_OUTPUT_NAMESPACE, + DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES, +) from ai.chronon.pyspark.executables import ( GroupByExecutable, JoinExecutable, PlatformInterface, ) -from ai.chronon.api.ttypes import GroupBy, Join - -from ai.chronon.pyspark.constants import ( - DATABRICKS_OUTPUT_NAMESPACE, - DATABRICKS_JVM_LOG_FILE, - DATABRICKS_ROOT_DIR_FOR_IMPORTED_FEATURES, -) class DatabricksPlatform(PlatformInterface): """ Databricks-specific implementation of the platform interface. """ - - def __init__(self, spark: SparkSession): """ Initialize Databricks-specific components. @@ -49,8 +47,12 @@ def get_constants_provider(self) -> JavaObject: Returns: A JavaObject representing the constants provider """ - constants_provider: JavaObject = self.jvm.ai.chronon.spark.databricks.DatabricksConstantsNameProvider() - self.jvm.ai.chronon.api.Constants.initConstantNameProvider(constants_provider) + constants_provider: JavaObject = ( + self.jvm.ai.chronon.spark.databricks.DatabricksConstantsNameProvider() + ) + self.jvm.ai.chronon.api.Constants.initConstantNameProvider( + constants_provider + ) return constants_provider @override @@ -61,7 +63,9 @@ def get_table_utils(self) -> JavaObject: Returns: A JavaObject representing the table utilities """ - return self.jvm.ai.chronon.spark.databricks.DatabricksTableUtils(self.java_spark_session) + return self.jvm.ai.chronon.spark.databricks.DatabricksTableUtils( + self.java_spark_session + ) @override def register_udfs(self) -> None: @@ -102,7 +106,6 @@ def end_log_capture(self, capture_token: tuple[int, str]) -> None: print(file_handler.read()) print("*" * 10, f" END LOGS FOR {job_name} ", "*" * 10, "\n\n") - def get_databricks_user(self) -> str: """ Get the current Databricks user. 
@@ -110,7 +113,8 @@ def get_databricks_user(self) -> str: Returns: The username of the current Databricks user """ - user_email = self.dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get() + user_email = self.dbutils.notebook.entry_point.getDbutils().notebook( + ).getContext().userName().get() return user_email.split('@')[0].lower() From 6f80b89ce701e2783d17429223ae0db13374fd13 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:37:39 -0700 Subject: [PATCH 13/31] Reformat executables to be flake8 compliant --- api/py/ai/chronon/pyspark/executables.py | 443 ++++++++++++++++------- 1 file changed, 303 insertions(+), 140 deletions(-) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index 107c04bad7..b12ed498d3 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -1,18 +1,21 @@ from __future__ import annotations -from abc import ABC, abstractmethod import copy +from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import TypeVar, Generic, cast, Any -from py4j.java_gateway import JavaObject, JVMView +from typing import Any, Generic, TypeVar, cast +from py4j.java_gateway import JVMView, JavaObject from pyspark.sql import DataFrame, SparkSession -from py4j.java_gateway import JavaObject -from ai.chronon.api.ttypes import GroupBy, Join, JoinPart, JoinSource, Source, Query, StagingQuery -from ai.chronon.utils import __set_name as set_name, get_max_window_for_gb_in_days, output_table_name -from ai.chronon.repo.serializer import thrift_simple_json +from ai.chronon.api.ttypes import ( + GroupBy, Join, JoinPart, JoinSource, Query, Source, StagingQuery +) from ai.chronon.pyspark.constants import PARTITION_COLUMN_FORMAT +from ai.chronon.repo.serializer import thrift_simple_json +from ai.chronon.utils import ( + __set_name as set_name, get_max_window_for_gb_in_days, output_table_name +) # Define type variable for our executables T = TypeVar('T', GroupBy, Join) @@ -38,8 +41,12 @@ def __init__(self, obj: T, spark_session: SparkSession): self.jvm: JVMView = self.spark._jvm self.platform: PlatformInterface = self.get_platform() self.java_spark_session: JavaObject = self.spark._jsparkSession - self.default_start_date: str = (datetime.now() - timedelta(days=8)).strftime(PARTITION_COLUMN_FORMAT) - self.default_end_date: str = (datetime.now() - timedelta(days=1)).strftime(PARTITION_COLUMN_FORMAT) + self.default_start_date: str = ( + datetime.now() - timedelta(days=8) + ).strftime(PARTITION_COLUMN_FORMAT) + self.default_end_date: str = ( + datetime.now() - timedelta(days=1) + ).strftime(PARTITION_COLUMN_FORMAT) @abstractmethod def get_platform(self) -> PlatformInterface: @@ -51,8 +58,9 @@ def get_platform(self) -> PlatformInterface: """ pass - - def _update_query_dates(self, query: Query, start_date: str, end_date: str) -> Query: + def _update_query_dates( + self, query: Query, start_date: str, end_date: str + ) -> Query: """ Update start and end dates of a query. @@ -69,7 +77,9 @@ def _update_query_dates(self, query: Query, start_date: str, end_date: str) -> Q query_copy.endPartition = end_date return query_copy - def _update_source_dates(self, source: Source, start_date: str, end_date: str) -> Source: + def _update_source_dates( + self, source: Source, start_date: str, end_date: str + ) -> Source: """ Update start and end dates of a source. 
@@ -90,7 +100,9 @@ def _update_source_dates(self, source: Source, start_date: str, end_date: str) - cast(Query, source_copy.entities.query), start_date, end_date) return source_copy - def _execute_underlying_join_sources(self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int) -> None: + def _execute_underlying_join_sources( + self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int + ) -> None: """ Execute underlying join sources. @@ -100,7 +112,6 @@ def _execute_underlying_join_sources(self, group_bys: list[GroupBy], start_date: end_date: End date for execution step_days: Number of days to process in each step """ - joins_to_execute: list[Join] = [] join_sources_to_execute_start_dates: dict[str, str] = {} @@ -113,46 +124,71 @@ def _execute_underlying_join_sources(self, group_bys: list[GroupBy], start_date: if not group_by_join_sources: continue - - # Recall that records generated by the inner join are input events for the outer join - # Therefore in order to correctly aggregate the outer join, your inner join needs to be run from start_date - max_window_for_gb_in_days - max_window_for_gb_in_days: int = get_max_window_for_gb_in_days(group_by) + # Recall that records generated by the inner join are input events + # for the outer join. Therefore in order to correctly aggregate the + # outer join, your inner join needs to be run from + # start_date - max_window_for_gb_in_days + max_window_for_gb_in_days: int = get_max_window_for_gb_in_days( + group_by) shifted_start_date = ( - datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - - timedelta(days=max_window_for_gb_in_days) + datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - + timedelta(days=max_window_for_gb_in_days) ).strftime(PARTITION_COLUMN_FORMAT) for js in group_by_join_sources: js_name: str | None = js.join.metaData.name if js_name is None: - raise ValueError(f"Join source {js} does not have a name. Was set_metadata called?") + raise ValueError( + f"Join source {js} does not have a name. " + "Was set_metadata called?" 
+ ) if js_name not in join_sources_to_execute_start_dates: join_sources_to_execute_start_dates[js_name] = shifted_start_date joins_to_execute.append(js.join) else: - existing_start_date: str = join_sources_to_execute_start_dates[js_name] - join_sources_to_execute_start_dates[js_name] = min(shifted_start_date, existing_start_date) + existing_start_date: str = join_sources_to_execute_start_dates[ + js_name + ] + join_sources_to_execute_start_dates[js_name] = min( + shifted_start_date, existing_start_date + ) if not joins_to_execute: return - self.platform.log_operation(f"Executing {len(joins_to_execute)} Join Sources") + self.platform.log_operation( + f"Executing {len(joins_to_execute)} Join Sources" + ) for join in joins_to_execute: - j_start_date: str = join_sources_to_execute_start_dates[join.metaData.name] + j_start_date: str = join_sources_to_execute_start_dates[ + join.metaData.name + ] - executable_join_cls: type[JoinExecutable] = self.platform.get_executable_join_cls() + executable_join_cls: type[JoinExecutable] = ( + self.platform.get_executable_join_cls() + ) executable_join = executable_join_cls(join, self.spark) - self.platform.log_operation(f"Executing Join Source {join.metaData.name} from {j_start_date} to {end_date}") - _ = executable_join.run(start_date=j_start_date, end_date=end_date, step_days=step_days) + self.platform.log_operation( + f"Executing Join Source {join.metaData.name} " + f"from {j_start_date} to {end_date}" + ) + _ = executable_join.run( + start_date=j_start_date, end_date=end_date, step_days=step_days + ) - output_table_name_for_js: str = output_table_name(join, full_name=True) - self.platform.log_operation(f"Join Source {join.metaData.name} will be read from {output_table_name_for_js}") + output_table_name_for_js: str = output_table_name( + join, full_name=True + ) + self.platform.log_operation( + f"Join Source {join.metaData.name} will be read from " + f"{output_table_name_for_js}" + ) self.platform.log_operation("Finished executing Join Sources") @@ -162,20 +198,20 @@ def print_with_timestamp(self, message: str) -> None: time_str = current_utc_time.strftime('[%Y-%m-%d %H:%M:%S UTC]') print(f'{time_str} {message}') - def group_by_to_java(self, group_by: GroupBy) -> JavaObject: """ Convert GroupBy object to Java representation. 
Args: group_by: The GroupBy object to convert - end_date: End date for execution Returns: Java representation of the GroupBy """ json_representation: str = thrift_simple_json(group_by) - java_group_by: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseGroupBy(json_representation) + java_group_by: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseGroupBy( + json_representation + ) return java_group_by def join_to_java(self, join: Join) -> JavaObject: @@ -184,20 +220,23 @@ def join_to_java(self, join: Join) -> JavaObject: Args: join: The Join object to convert - end_date: End date for execution Returns: Java representation of the Join """ json_representation: str = thrift_simple_json(join) - java_join: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseJoin(json_representation) + java_join: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseJoin( + json_representation + ) return java_join class GroupByExecutable(PySparkExecutable[GroupBy], ABC): """Interface for executing GroupBy objects""" - def _update_source_dates_for_group_by(self, group_by: GroupBy, start_date: str, end_date: str) -> GroupBy: + def _update_source_dates_for_group_by( + self, group_by: GroupBy, start_date: str, end_date: str + ) -> GroupBy: """ Update start and end dates of sources in GroupBy. @@ -213,14 +252,18 @@ def _update_source_dates_for_group_by(self, group_by: GroupBy, start_date: str, return group_by for i, source in enumerate(group_by.sources): - group_by.sources[i] = self._update_source_dates(source, start_date, end_date) + group_by.sources[i] = self._update_source_dates( + source, start_date, end_date + ) return group_by - def run(self, - start_date: str | None = None, - end_date: str | None = None, - step_days: int = 30, - skip_execution_of_underlying_join: bool = False) -> DataFrame: + def run( + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_execution_of_underlying_join: bool = False + ) -> DataFrame: """ Execute the GroupBy object. 
@@ -228,63 +271,91 @@ def run(self, start_date: Start date for the execution (format: YYYYMMDD) end_date: End date for the execution (format: YYYYMMDD) step_days: Number of days to process in each step - skip_execution_of_underlying_join: Whether to skip execution of underlying joins + skip_execution_of_underlying_join: Whether to skip execution of + underlying joins Returns: DataFrame with the execution results """ - start_date: str = start_date or self.default_start_date end_date: str = end_date or self.default_end_date - self.platform.log_operation(f"Executing GroupBy {self.obj.metaData.name} from {start_date} to {end_date} with step_days {step_days}") - self.platform.log_operation(f"Skip Execution of Underlying Join Sources: {skip_execution_of_underlying_join}") + self.platform.log_operation( + f"Executing GroupBy {self.obj.metaData.name} from " + f"{start_date} to {end_date} with step_days {step_days}" + ) + self.platform.log_operation( + f"Skip Execution of Underlying Join Sources: " + f"{skip_execution_of_underlying_join}" + ) if not skip_execution_of_underlying_join: - self._execute_underlying_join_sources(group_bys=[self.obj], start_date=start_date, end_date=end_date, step_days=step_days) + self._execute_underlying_join_sources( + group_bys=[self.obj], + start_date=start_date, + end_date=end_date, + step_days=step_days + ) # Prepare GroupBy for execution group_by_to_execute: GroupBy = copy.deepcopy(self.obj) group_by_to_execute.backfillStartDate = start_date # Update sources with correct dates - group_by_to_execute: GroupBy = self._update_source_dates_for_group_by(group_by_to_execute, start_date, end_date) + group_by_to_execute = self._update_source_dates_for_group_by( + group_by_to_execute, start_date, end_date + ) # Get output table name - group_by_output_table: str = output_table_name(group_by_to_execute, full_name=True) + group_by_output_table: str = output_table_name( + group_by_to_execute, full_name=True + ) - # GroupBy backfills don't store the semantic hash as a property in the table the same way joins do. + # GroupBy backfills don't store the semantic hash as a property in + # the table the same way joins do. # Therefore we drop the backfill table to avoid data quality issues. 
self.platform.drop_table_if_exists(table_name=group_by_output_table) # Find starting point for log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Run GroupBy: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Run GroupBy: {self.obj.metaData.name}" + ) try: # Convert to Java GroupBy java_group_by = self.group_by_to_java(group_by_to_execute) # Execute GroupBy - result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runGroupBy( - java_group_by, - end_date, - self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(str(step_days)), - self.platform.get_table_utils(), - self.platform.get_constants_provider() + result_df_scala: JavaObject = ( + self.jvm.ai.chronon.spark.PySparkUtils.runGroupBy( + java_group_by, + end_date, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + str(step_days)), + self.platform.get_table_utils(), + self.platform.get_constants_provider() + ) ) result_df = DataFrame(result_df_scala, self.spark) self.platform.end_log_capture(log_token) - self.platform.log_operation(f"GroupBy {self.obj.metaData.name} executed successfully and was written to {group_by_output_table}") + self.platform.log_operation( + f"GroupBy {self.obj.metaData.name} executed successfully " + f"and was written to {group_by_output_table}" + ) return result_df except Exception as e: self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Execution failed for GroupBy {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Execution failed for GroupBy {self.obj.metaData.name}: {str(e)}" + ) raise e - def analyze(self, - start_date: str | None = None, - end_date: str | None = None, - enable_hitter_analysis: bool = False) -> None: + def analyze( + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False + ) -> None: """ Analyze the GroupBy object. 
@@ -296,17 +367,26 @@ def analyze(self, start_date = start_date or self.default_start_date end_date = end_date or self.default_end_date - self.platform.log_operation(f"Analyzing GroupBy {self.obj.metaData.name} from {start_date} to {end_date}") - self.platform.log_operation(f"Enable Hitter Analysis: {enable_hitter_analysis}") + self.platform.log_operation( + f"Analyzing GroupBy {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) + self.platform.log_operation( + f"Enable Hitter Analysis: {enable_hitter_analysis}" + ) # Prepare GroupBy for analysis group_by_to_analyze: GroupBy = copy.deepcopy(self.obj) # Update sources with correct dates - group_by_to_analyze: GroupBy = self._update_source_dates_for_group_by(group_by_to_analyze, start_date, end_date) + group_by_to_analyze = self._update_source_dates_for_group_by( + group_by_to_analyze, start_date, end_date + ) # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Analyze GroupBy: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Analyze GroupBy: {self.obj.metaData.name}" + ) try: # Convert to Java GroupBy @@ -321,16 +401,21 @@ def analyze(self, self.platform.get_constants_provider() ) self.platform.end_log_capture(log_token) - self.platform.log_operation(f"GroupBy {self.obj.metaData.name} analyzed successfully") + self.platform.log_operation( + f"GroupBy {self.obj.metaData.name} analyzed successfully" + ) except Exception as e: self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Analysis failed for GroupBy {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Analysis failed for GroupBy {self.obj.metaData.name}: {str(e)}" + ) raise e - - def validate(self, - start_date: str | None = None, - end_date: str | None = None) -> None: + def validate( + self, + start_date: str | None = None, + end_date: str | None = None + ) -> None: """ Validate the GroupBy object. 
@@ -338,20 +423,26 @@ def validate(self, start_date: Start date for validation (format: YYYYMMDD) end_date: End date for validation (format: YYYYMMDD) """ - platform = self.get_platform() start_date = start_date or self.default_start_date end_date = end_date or self.default_end_date - self.platform.log_operation(f"Validating GroupBy {self.obj.metaData.name} from {start_date} to {end_date}") + self.platform.log_operation( + f"Validating GroupBy {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) # Prepare GroupBy for validation group_by_to_validate = copy.deepcopy(self.obj) # Update sources with correct dates - group_by_to_validate: GroupBy = self._update_source_dates_for_group_by(group_by_to_validate, start_date, end_date) + group_by_to_validate = self._update_source_dates_for_group_by( + group_by_to_validate, start_date, end_date + ) # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Validate GroupBy: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Validate GroupBy: {self.obj.metaData.name}" + ) try: # Convert to Java GroupBy @@ -366,23 +457,26 @@ def validate(self, ) self.platform.end_log_capture(log_token) - self.platform.handle_validation_errors(errors_list, f"GroupBy {self.obj.metaData.name}") - self.platform.log_operation(f"Validation for GroupBy {self.obj.metaData.name} has completed") + self.platform.handle_validation_errors( + errors_list, f"GroupBy {self.obj.metaData.name}" + ) + self.platform.log_operation( + f"Validation for GroupBy {self.obj.metaData.name} has completed" + ) except Exception as e: - self.platform.log_operation(f"Validation failed for GroupBy {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Validation failed for GroupBy {self.obj.metaData.name}: {str(e)}" + ) self.platform.end_log_capture(log_token) raise e - - - - class JoinExecutable(PySparkExecutable[Join], ABC): """Interface for executing Join objects""" - - def _update_source_dates_for_join_parts(self, join_parts: list[JoinPart], start_date: str, end_date: str) -> list[JoinPart]: + def _update_source_dates_for_join_parts( + self, join_parts: list[JoinPart], start_date: str, end_date: str + ) -> list[JoinPart]: """ Update start and end dates of sources in JoinParts. @@ -399,16 +493,20 @@ def _update_source_dates_for_join_parts(self, join_parts: list[JoinPart], start_ for jp in join_parts: for i, source in enumerate(jp.groupBy.sources): - jp.groupBy.sources[i] = self._update_source_dates(source, start_date, end_date) + jp.groupBy.sources[i] = self._update_source_dates( + source, start_date, end_date + ) return join_parts - def run(self, - start_date: str | None = None, - end_date: str | None = None, - step_days: int = 30, - skip_first_hole: bool = False, - sample_num_of_rows: int | None = None, - skip_execution_of_underlying_join: bool = False) -> DataFrame: + def run( + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_first_hole: bool = False, + sample_num_of_rows: int | None = None, + skip_execution_of_underlying_join: bool = False + ) -> DataFrame: """ Execute the Join object with Join-specific parameters. 
@@ -418,7 +516,8 @@ def run(self, step_days: Number of days to process in each step skip_first_hole: Whether to skip the first hole in the join sample_num_of_rows: Number of rows to sample (None for all) - skip_execution_of_underlying_join: Whether to skip execution of underlying joins + skip_execution_of_underlying_join: Whether to skip execution of + underlying joins Returns: DataFrame with the execution results @@ -426,18 +525,27 @@ def run(self, start_date = start_date or self.default_start_date end_date = end_date or self.default_end_date - self.platform.log_operation(f"Executing Join {self.obj.metaData.name} from {start_date} to {end_date} with step_days {step_days}") + self.platform.log_operation( + f"Executing Join {self.obj.metaData.name} from " + f"{start_date} to {end_date} with step_days {step_days}" + ) self.platform.log_operation(f"Skip First Hole: {skip_first_hole}") self.platform.log_operation(f"Sample Number of Rows: {sample_num_of_rows}") - self.platform.log_operation(f"Skip Execution of Underlying Join: {skip_execution_of_underlying_join}") + self.platform.log_operation( + f"Skip Execution of Underlying Join: {skip_execution_of_underlying_join}" + ) # Prepare Join for execution join_to_execute = copy.deepcopy(self.obj) - join_to_execute.left = self._update_source_dates(join_to_execute.left, start_date, end_date) + join_to_execute.left = self._update_source_dates( + join_to_execute.left, start_date, end_date + ) if not skip_execution_of_underlying_join and self.obj.joinParts: - self._execute_underlying_join_sources([jp.groupBy for jp in join_to_execute.joinParts], start_date, end_date, step_days) - + self._execute_underlying_join_sources( + [jp.groupBy for jp in join_to_execute.joinParts], + start_date, end_date, step_days + ) # Update join parts sources join_to_execute.joinParts = self._update_source_dates_for_join_parts( @@ -448,7 +556,9 @@ def run(self, join_output_table = output_table_name(join_to_execute, full_name=True) # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Run Join: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Run Join: {self.obj.metaData.name}" + ) try: # Convert to Java Join @@ -457,26 +567,37 @@ def run(self, result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runJoin( java_join, end_date, - self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(str(step_days)), + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + str(step_days) + ), skip_first_hole, - self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional(None if not sample_num_of_rows else str(sample_num_of_rows)), + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + None if not sample_num_of_rows else str(sample_num_of_rows) + ), self.platform.get_table_utils(), self.platform.get_constants_provider() ) result_df = DataFrame(result_df_scala, self.spark) self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Join {self.obj.metaData.name} executed successfully and was written to {join_output_table}") + self.platform.log_operation( + f"Join {self.obj.metaData.name} executed successfully and " + f"was written to {join_output_table}" + ) return result_df except Exception as e: self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Execution failed for Join {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Execution failed for Join {self.obj.metaData.name}: {str(e)}" + ) raise e - def analyze(self, - start_date: str | None = None, - 
end_date: str | None = None, - enable_hitter_analysis: bool = False) -> None: + def analyze( + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False + ) -> None: """ Analyze the Join object. @@ -488,12 +609,19 @@ def analyze(self, start_date: str = start_date or self.default_start_date end_date: str = end_date or self.default_end_date - self.platform.log_operation(f"Analyzing Join {self.obj.metaData.name} from {start_date} to {end_date}") - self.platform.log_operation(f"Enable Hitter Analysis: {enable_hitter_analysis}") + self.platform.log_operation( + f"Analyzing Join {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) + self.platform.log_operation( + f"Enable Hitter Analysis: {enable_hitter_analysis}" + ) # Prepare Join for analysis join_to_analyze: Join = copy.deepcopy(self.obj) - join_to_analyze.left = self._update_source_dates(join_to_analyze.left, start_date, end_date) + join_to_analyze.left = self._update_source_dates( + join_to_analyze.left, start_date, end_date + ) # Update join parts sources join_to_analyze.joinParts = self._update_source_dates_for_join_parts( @@ -501,7 +629,9 @@ def analyze(self, ) # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Analyze Join: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Analyze Join: {self.obj.metaData.name}" + ) try: # Convert to Java Join @@ -516,19 +646,22 @@ def analyze(self, self.platform.get_constants_provider() ) self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Join {self.obj.metaData.name} analyzed successfully") + self.platform.log_operation( + f"Join {self.obj.metaData.name} analyzed successfully" + ) except Exception as e: self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Analysis failed for Join {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Analysis failed for Join {self.obj.metaData.name}: {str(e)}" + ) raise e - - - - def validate(self, - start_date: str | None = None, - end_date: str | None = None) -> None: + def validate( + self, + start_date: str | None = None, + end_date: str | None = None + ) -> None: """ Validate the Join object. 
@@ -539,11 +672,16 @@ def validate(self, start_date: str = start_date or self.default_start_date end_date: str = end_date or self.default_end_date - self.platform.log_operation(f"Validating Join {self.obj.metaData.name} from {start_date} to {end_date}") + self.platform.log_operation( + f"Validating Join {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) # Prepare Join for validation join_to_validate: Join = copy.deepcopy(self.obj) - join_to_validate.left = self._update_source_dates(join_to_validate.left, start_date, end_date) + join_to_validate.left = self._update_source_dates( + join_to_validate.left, start_date, end_date + ) # Update join parts sources join_to_validate.joinParts = self._update_source_dates_for_join_parts( @@ -551,7 +689,9 @@ def validate(self, ) # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture(f"Validate Join: {self.obj.metaData.name}") + log_token = self.platform.start_log_capture( + f"Validate Join: {self.obj.metaData.name}" + ) try: # Convert to Java Join @@ -567,16 +707,20 @@ def validate(self, self.platform.end_log_capture(log_token) # Handle validation errors - self.platform.handle_validation_errors(errors_list, f"Join {self.obj.metaData.name}") - self.platform.log_operation(f"Validation for Join {self.obj.metaData.name} has completed") + self.platform.handle_validation_errors( + errors_list, f"Join {self.obj.metaData.name}" + ) + self.platform.log_operation( + f"Validation for Join {self.obj.metaData.name} has completed" + ) except Exception as e: self.platform.end_log_capture(log_token) - self.platform.log_operation(f"Validation failed for Join {self.obj.metaData.name}: {str(e)}") + self.platform.log_operation( + f"Validation failed for Join {self.obj.metaData.name}: {str(e)}" + ) raise e - - class PlatformInterface(ABC): """ Interface for platform-specific operations. @@ -654,14 +798,21 @@ def register_udfs(self) -> None: """ Register UDFs for the self.platform. - This method is intentionally left empty but not abstract, as some platforms may not need to register UDFs. + This method is intentionally left empty but not abstract, as some + platforms may not need to register UDFs. - Subclasses can override this method to provide platform-specific UDF registration. + Subclasses can override this method to provide platform-specific UDF + registration. - Pro tip: Both the JVM Spark Session and Python Spark Session use the same spark-sql engine. You can register Python UDFS and use them in the JVM, as well as vice-versa. - At Stripe we currently only use Scala UDFs, so we include a JAR of our UDFs in the cluster and register them via: + Pro tip: Both the JVM Spark Session and Python Spark Session use the + same spark-sql engine. You can register Python UDFS and use them in + the JVM, as well as vice-versa. + At Stripe we currently only use Scala UDFs, so we include a JAR of our + UDFs in the cluster and register them via: - self.jvm.com.stripe.shepherd.udfs.ShepherdUdfRegistry.registerAllUdfs(self.java_spark_session) + self.jvm.com.stripe.shepherd.udfs.ShepherdUdfRegistry.registerAllUdfs( + self.java_spark_session + ) """ pass @@ -685,7 +836,6 @@ def drop_table_if_exists(self, table_name: str) -> None: """ _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") - def handle_validation_errors(self, errors: JavaObject, object_name: str) -> None: """ Handle validation errors. 
@@ -695,13 +845,21 @@ def handle_validation_errors(self, errors: JavaObject, object_name: str) -> None object_name: Name of the object being validated """ if errors.length() > 0: - self.log_operation(message=f"Validation failed for {object_name} with the following errors:") + self.log_operation( + message=f"Validation failed for {object_name} " + + "with the following errors:" + ) self.log_operation(message=str(errors)) else: self.log_operation(message=f"Validation passed for {object_name}.") - def set_metadata(self, obj: GroupBy | Join | StagingQuery, mod_prefix: str, - name_prefix: str | None = None, output_namespace: str | None = None) -> T: + def set_metadata( + self, + obj: GroupBy | Join | StagingQuery, + mod_prefix: str, + name_prefix: str | None = None, + output_namespace: str | None = None + ) -> T: """ Set the metadata for the object. @@ -721,7 +879,10 @@ def set_metadata(self, obj: GroupBy | Join | StagingQuery, mod_prefix: str, try: set_name(obj=obj, cls=obj_type, mod_prefix=mod_prefix) except AttributeError: - raise AttributeError("Please provide a name when defining group_bys/joins/staging_queries adhoc.") + raise AttributeError( + "Please provide a name when defining " + "group_bys/joins/staging_queries adhoc." + ) # We do this to avoid adding the name prefix multiple times elif obj.metaData.name and name_prefix: obj.metaData.name = obj.metaData.name.replace(f"{name_prefix}_", "") @@ -736,7 +897,9 @@ def set_metadata(self, obj: GroupBy | Join | StagingQuery, mod_prefix: str, for jp in obj.joinParts: for s in jp.groupBy.sources: if s.joinSource and not s.joinSource.join.metaData.name: - set_name(obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix) + set_name( + obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix + ) # Set output namespace if output_namespace: From f5c83eb2c83d08b9b9163531e62a1ef2c66ae4f8 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:41:57 -0700 Subject: [PATCH 14/31] Reformat executables to be flake8 compliant --- api/py/ai/chronon/pyspark/executables.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index b12ed498d3..305009eeeb 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -854,10 +854,10 @@ def handle_validation_errors(self, errors: JavaObject, object_name: str) -> None self.log_operation(message=f"Validation passed for {object_name}.") def set_metadata( - self, - obj: GroupBy | Join | StagingQuery, + self, + obj: GroupBy | Join | StagingQuery, mod_prefix: str, - name_prefix: str | None = None, + name_prefix: str | None = None, output_namespace: str | None = None ) -> T: """ @@ -914,4 +914,3 @@ def set_metadata( obj.metaData.name = f"{name_prefix}_{obj.metaData.name}" return obj - From 5346d5b178fb31be4839cd6929a32910e351806f Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:43:19 -0700 Subject: [PATCH 15/31] Reformat executables to be flake8 compliant --- api/py/ai/chronon/pyspark/executables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index 305009eeeb..d750228a77 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -881,7 +881,7 @@ def set_metadata( except 
AttributeError: raise AttributeError( "Please provide a name when defining " - "group_bys/joins/staging_queries adhoc." + + "group_bys/joins/staging_queries adhoc." ) # We do this to avoid adding the name prefix multiple times elif obj.metaData.name and name_prefix: From a69a5b0ad25315e6eceb3619318618285102eafc Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:46:05 -0700 Subject: [PATCH 16/31] Reformat utils to be flake8 compliant --- api/py/ai/chronon/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index 99741862ca..a137ef1847 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -603,6 +603,7 @@ def get_max_window_for_gb_in_days(group_by: api.GroupBy) -> int: ) else: raise ValueError( - f"Unsupported time unit {window.timeUnit}. Please add logic above to handle the newly introduced time unit." + f"Unsupported time unit {window.timeUnit}. " + + "Please add logic above to handle the newly introduced time unit." ) - return result \ No newline at end of file + return result From 6f72c4eee870db696ccf1fb6b127863b1f9cbe1e Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:51:39 -0700 Subject: [PATCH 17/31] Remove stripe specific logic from Databricks Table Utils --- .../databricks/DatabricksTableUtils.scala | 115 ++---------------- 1 file changed, 8 insertions(+), 107 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index 5efc9e0708..4a5223da9f 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -5,111 +5,12 @@ import ai.chronon.spark.BaseTableUtils import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Column, SparkSession} -case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils { - private def isTempView(tableName: String): Boolean = - sparkSession.sessionState.catalog.isTempView(tableName.split('.')) - - override def createTableSql(tableName: String, schema: StructType, partitionColumns: Seq[String], tableProperties: Map[String, String], fileFormat: String): String = { - - // Side effect: Creates the table in iceberg and then confirms the table creation was successfully by issuing a SHOW CREATE TABLE. 
- val partitionSparkCols: Seq[Column] = partitionColumns.map(col => org.apache.spark.sql.functions.col(col)) - - val emptyDf = sparkSession.createDataFrame(sparkSession.sparkContext.emptyRDD[org.apache.spark.sql.Row], schema) - - val createTableWriter = { - if(Option(tableProperties).exists(_.nonEmpty)) { - tableProperties - .foldLeft(emptyDf.writeTo(tableName).using("iceberg"))((createTableWriter, tblProperty) => createTableWriter.tableProperty(tblProperty._1, tblProperty._2)) - } - else { - emptyDf.writeTo(tableName).using("iceberg") - } - } - - - partitionSparkCols match { - case Seq() => - logger.info(s"Creating table $tableName without partitioning") - createTableWriter.create() - case Seq(head) => - logger.info(s"Creating table $tableName partitioned by $head") - createTableWriter.partitionedBy(head).create() - case head +: tail => - logger.info(s"Creating table $tableName partitioned by $head and $tail") - createTableWriter.partitionedBy(head, tail: _*).create() - } - - s"SHOW CREATE TABLE $tableName" // Confirm table creation in subsequent call - } - - override def getIcebergPartitions(tableName: String): Seq[String] = { - val partitionsDf = sparkSession.read.format("iceberg").load(s"$tableName.partitions") - val index = partitionsDf.schema.fieldIndex("partition") - if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("hr")) { - // Hour filter is currently buggy in iceberg. https://github.com/apache/iceberg/issues/4718 - // so we collect and then filter. - partitionsDf - .select(s"partition.${partitionColumn}", s"partition.${Constants.HourPartitionColumn}") - .collect() - .filter(_.get(1) == null) - .map(_.getString(0)) - .toSeq - } else if (partitionsDf.schema(index).dataType.asInstanceOf[StructType].fieldNames.contains("locality_zone")) { - partitionsDf - .select(s"partition.${partitionColumn}") - .where("partition.locality_zone == 'DEFAULT'") - .collect() - .map(_.getString(0)) - .toSeq - } - else { - partitionsDf - .select(s"partition.${partitionColumn}") - .collect() - .map(_.getString(0)) - .toSeq - } - } - - override def partitions( - tableName: String, - subPartitionsFilter: Map[String, String] = Map.empty, - partitionColumnOverride: String = Constants.PartitionColumn - ): Seq[String] = - // This is to support s3 prefix inputs. - if (isTempView(tableName)) { - if (subPartitionsFilter.nonEmpty) { - throw new NotImplementedError( - s"partitions cannot be called with tableName ${tableName} subPartitionsFilter ${subPartitionsFilter} because subPartitionsFilter is not supported on tempViews yet." - ) - } - // If the table is a temp view, fallback to querying for the distinct partition columns because SHOW PARTITIONS - // doesn't work on temp views. This can be inefficient for large tables. - logger.info( - s"Selecting partitions for temp view table tableName=$tableName, " + - s"partitionColumnOverride=$partitionColumnOverride." - ) - val outputPartitionsRowsDf = sql( - s"select distinct ${partitionColumnOverride} from $tableName" - ) - val outputPartitionsRows = outputPartitionsRowsDf.collect() - // Users have ran into issues where the partition column is not a string, so add logging to facilitate debug. - logger.info( - s"Found ${outputPartitionsRows.length} partitions for temp view table tableName=$tableName. The partition schema is ${outputPartitionsRowsDf.schema}." 
- ) - - val outputPartitionsRowsDistinct = outputPartitionsRows.map(_.getString(0)).distinct - logger.info( - s"Converted ${outputPartitionsRowsDistinct.length} distinct partitions for temp view table to String. " + - s"tableName=$tableName." - ) - - outputPartitionsRowsDistinct - } else { - super.partitions(tableName, subPartitionsFilter, partitionColumnOverride) - } - - - -} \ No newline at end of file +/** + * DatabricksTableUtils is the table utils class used in our Databricks integration. + * If you need any specific functionality pertaining to reads/writes for your Databricks setup, + * you can implement it here. + * + * @param sparkSession The Spark session used for table operations. + */ +case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils \ No newline at end of file From a0155836edb9ae0fdb83c69c98194a0e565200a4 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:52:00 -0700 Subject: [PATCH 18/31] Remove stripe specific logic from Databricks Table Utils --- .../ai/chronon/spark/databricks/DatabricksTableUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index 4a5223da9f..3e52a2f1ed 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -13,4 +13,4 @@ import org.apache.spark.sql.{Column, SparkSession} * * @param sparkSession The Spark session used for table operations. */ -case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils \ No newline at end of file +case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils From 6792e0edfd1aaeb9629f7582fb8be1cc8f310c6b Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 17:57:12 -0700 Subject: [PATCH 19/31] add explanation to databricks constants provider --- .../databricks/DatabricksConstantsNameProvider.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala index 6d3566e806..8535f970eb 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala @@ -3,10 +3,14 @@ package ai.chronon.spark.databricks import ai.chronon.api.Extensions.{WindowOps, WindowUtils} import ai.chronon.api.{ConstantNameProvider, PartitionSpec} +/** + * DatabricksConstantsNameProvider provides JVM constants used in our Databricks integration. + * If you need any specific functionality pertaining to your Databricks JVM execution, + * you can implement it here. 
+ */ class DatabricksConstantsNameProvider extends ConstantNameProvider with Serializable { override def TimeColumn: String = "_internal_time_column" override def DatePartitionColumn: String = "day" - override def HourPartitionColumn: String = "hr" override def Partition: PartitionSpec = PartitionSpec(format = "yyyyMMdd", spanMillis = WindowUtils.Day.millis) -} \ No newline at end of file +} From 6e859e0c6e5840e4590b7a3d9f82bbe5fd952eab Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Thu, 24 Apr 2025 18:32:54 -0700 Subject: [PATCH 20/31] Run scalafmt on new scala code --- .../scala/ai/chronon/spark/PySparkUtils.scala | 78 ++++++++++--------- .../DatabricksConstantsNameProvider.scala | 3 + .../databricks/DatabricksTableUtils.scala | 2 +- 3 files changed, 47 insertions(+), 36 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index 387e2c3300..f1ce507a49 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -1,4 +1,5 @@ package ai.chronon.spark + import ai.chronon.aggregator.windowing.{FiveMinuteResolution, Resolution} import ai.chronon.api import ai.chronon.api.Extensions.MetadataOps @@ -11,6 +12,7 @@ object PySparkUtils { /** * Pyspark has a tough time creating a FiveMinuteResolution via jvm.ai.chronon.aggregator.windowing.FiveMinuteResolution so we provide this helper method + * * @return FiveMinuteResolution */ @@ -18,6 +20,7 @@ object PySparkUtils { /** * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * * @param timeRange a time range * @return Empty time range optional */ @@ -25,30 +28,34 @@ object PySparkUtils { /** * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * * @param str a string * @return String optional */ - def getStringOptional(str : String) : Option[String] = if (str == null) Option.empty[String] else Option(str) + def getStringOptional(str: String): Option[String] = if (str == null) Option.empty[String] else Option(str) /** * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around * Furthermore, ints can't be null in Scala so we need to pass the value in as a str + * * @param strInt a string * @return Int optional */ - def getIntOptional(strInt : String) : Option[Int] = if (strInt == null) Option.empty[Int] else Option(strInt.toInt) + def getIntOptional(strInt: String): Option[Int] = if (strInt == null) Option.empty[Int] else Option(strInt.toInt) /** * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * * @param groupByJson a JSON string representing a group by * @return Chronon Scala API GroupBy object */ def parseGroupBy(groupByJson: String): api.GroupBy = { - ThriftJsonCodec.fromJsonStr[api.GroupBy](groupByJson, check = true, classOf[api.GroupBy]) + ThriftJsonCodec.fromJsonStr[api.GroupBy](groupByJson, check = true, classOf[api.GroupBy]) } /** * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * * @param joinJson a JSON string representing a join * @return Chronon Scala API Join object */ @@ -58,6 +65,7 @@ object PySparkUtils { /** * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr 
+ * * @param sourceJson a JSON string representing a source. * @return Chronon Scala API Source object */ @@ -78,14 +86,14 @@ object PySparkUtils { /** * Helper function to allow a user to execute a Group By. * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param endDate str this represents the last date we will perform the aggregation for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param endDate str this represents the last date we will perform the aggregation for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. * @return DataFrame */ - def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : DataFrame = { + def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): DataFrame = { logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) GroupBy.computeBackfill( @@ -101,10 +109,10 @@ object PySparkUtils { /** * Helper function to allow a user to execute a Join. * - * @param joinConf api.Join Chronon scala Join API object - * @param endDate str this represents the last date we will perform the Join for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param joinConf api.Join Chronon scala Join API object + * @param endDate str this represents the last date we will perform the Join for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. * @return DataFrame */ @@ -115,7 +123,7 @@ object PySparkUtils { sampleNumOfRows: Option[Int], tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider - ) : DataFrame = { + ): DataFrame = { logger.info(s"Executing Join ${joinConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) val join = new Join( @@ -132,14 +140,14 @@ object PySparkUtils { /** * Helper function to validate a GroupBy * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param startDate start date for the group by - * @param endDate end date for the group by - * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
* @return DataFrame */ - def validateGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : List[String] = { + def validateGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): List[String] = { logger.info(s"Validating GroupBy ${groupByConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) val validator = new Validator(tableUtils, groupByConf, startDate, endDate) @@ -152,14 +160,14 @@ object PySparkUtils { /** * Helper function to validate a Join * - * @param joinConf api.Join Chronon scala Join API object - * @param startDate start date for the join - * @param endDate end date for the join - * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end date for the join + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. * @return DataFrame */ - def validateJoin(joinConf: api.Join, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : List[String] = { + def validateJoin(joinConf: api.Join, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): List[String] = { logger.info(s"Validating Join: ${joinConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) val validator = new Validator(tableUtils, joinConf, startDate, endDate) @@ -171,14 +179,14 @@ object PySparkUtils { /** * Helper function to analyze a GroupBy * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param startDate start date for the group by - * @param endDate end date for the group by + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
*/ - def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : Unit = { + def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): Unit = { logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) @@ -190,15 +198,15 @@ object PySparkUtils { /** * Helper function to analyze a Join * - * @param joinConf api.Join Chronon scala Join API object - * @param startDate start date for the join - * @param endDate end date for the join + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end date for the join * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. * @return DataFrame */ - def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider) : Unit = { + def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): Unit = { logger.info(s"Analyzing Join: ${joinConf.metaData.name}") Constants.initConstantNameProvider(constantsProvider) val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala index 8535f970eb..8ef8f49b52 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala @@ -10,7 +10,10 @@ import ai.chronon.api.{ConstantNameProvider, PartitionSpec} */ class DatabricksConstantsNameProvider extends ConstantNameProvider with Serializable { override def TimeColumn: String = "_internal_time_column" + override def DatePartitionColumn: String = "day" + override def HourPartitionColumn: String = "hr" + override def Partition: PartitionSpec = PartitionSpec(format = "yyyyMMdd", spanMillis = WindowUtils.Day.millis) } diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index 3e52a2f1ed..4d095d8961 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -13,4 +13,4 @@ import org.apache.spark.sql.{Column, SparkSession} * * @param sparkSession The Spark session used for table operations. 
*/ -case class DatabricksTableUtils (override val sparkSession: SparkSession) extends BaseTableUtils +case class DatabricksTableUtils(override val sparkSession: SparkSession) extends BaseTableUtils From 54c1828fcd27d1b24d6808df95357bbdf3250c79 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:14:06 -0700 Subject: [PATCH 21/31] Remove validate functions as the spark validator does not exist is OSS repo. Also remove constants provider for same reason --- .../scala/ai/chronon/spark/PySparkUtils.scala | 74 ++++--------------- 1 file changed, 13 insertions(+), 61 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index f1ce507a49..4b99237cda 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -3,7 +3,7 @@ package ai.chronon.spark import ai.chronon.aggregator.windowing.{FiveMinuteResolution, Resolution} import ai.chronon.api import ai.chronon.api.Extensions.MetadataOps -import ai.chronon.api.{ConstantNameProvider, Constants, ThriftJsonCodec} +import ai.chronon.api.ThriftJsonCodec import org.apache.spark.sql.DataFrame import org.slf4j.LoggerFactory @@ -86,16 +86,14 @@ object PySparkUtils { /** * Helper function to allow a user to execute a Group By. * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param endDate str this represents the last date we will perform the aggregation for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param endDate str this represents the last date we will perform the aggregation for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @return DataFrame */ - def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): DataFrame = { + def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils): DataFrame = { logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) GroupBy.computeBackfill( groupByConf, endDate, @@ -109,11 +107,10 @@ object PySparkUtils { /** * Helper function to allow a user to execute a Join. * - * @param joinConf api.Join Chronon scala Join API object - * @param endDate str this represents the last date we will perform the Join for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
+ * @param joinConf api.Join Chronon scala Join API object + * @param endDate str this represents the last date we will perform the Join for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources * @return DataFrame */ def runJoin(joinConf: api.Join, @@ -121,11 +118,9 @@ object PySparkUtils { stepDays: Option[Int], skipFirstHole: Boolean, sampleNumOfRows: Option[Int], - tableUtils: BaseTableUtils, - constantsProvider: ConstantNameProvider + tableUtils: BaseTableUtils ): DataFrame = { logger.info(s"Executing Join ${joinConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) val join = new Join( joinConf, endDate, @@ -137,45 +132,6 @@ object PySparkUtils { resultDf } - /** - * Helper function to validate a GroupBy - * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param startDate start date for the group by - * @param endDate end date for the group by - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. - * @return DataFrame - */ - def validateGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): List[String] = { - logger.info(s"Validating GroupBy ${groupByConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) - val validator = new Validator(tableUtils, groupByConf, startDate, endDate) - val result = validator.validateGroupBy(groupByConf) - logger.info(s"Finished validating GroupBy ${groupByConf.metaData.name}") - result - } - - - /** - * Helper function to validate a Join - * - * @param joinConf api.Join Chronon scala Join API object - * @param startDate start date for the join - * @param endDate end date for the join - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. - * @return DataFrame - */ - def validateJoin(joinConf: api.Join, startDate: String, endDate: String, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): List[String] = { - logger.info(s"Validating Join: ${joinConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) - val validator = new Validator(tableUtils, joinConf, startDate, endDate) - val result = validator.validateJoin(joinConf) - logger.info(s"Finished validating Join: ${joinConf.metaData.name}") - result - } - /** * Helper function to analyze a GroupBy * @@ -184,11 +140,9 @@ object PySparkUtils { * @param endDate end date for the group by * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. 
*/ - def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): Unit = { + def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils): Unit = { logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeGroupBy(groupByConf, enableHitter = enableHitterAnalysis) logger.info(s"Finished analyzing GroupBy: ${groupByConf.metaData.name}") @@ -203,12 +157,10 @@ object PySparkUtils { * @param endDate end date for the join * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @param constantsProvider ConstantsProvider must be set from the Scala side. Doing so from PySpark will not properly set it. * @return DataFrame */ - def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils, constantsProvider: ConstantNameProvider): Unit = { + def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils): Unit = { logger.info(s"Analyzing Join: ${joinConf.metaData.name}") - Constants.initConstantNameProvider(constantsProvider) val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeJoin(joinConf, enableHitter = enableHitterAnalysis) logger.info(s"Finished analyzing Join: ${joinConf.metaData.name}") From c3001973c02e65514e1e48499336b973375e8704 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:22:42 -0700 Subject: [PATCH 22/31] Remove constants provider + validate logic from python api --- api/py/ai/chronon/pyspark/executables.py | 160 +----------------- .../DatabricksConstantsNameProvider.scala | 19 --- 2 files changed, 4 insertions(+), 175 deletions(-) delete mode 100644 spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index d750228a77..bc1b7cc3de 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -331,8 +331,7 @@ def run( end_date, self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( str(step_days)), - self.platform.get_table_utils(), - self.platform.get_constants_provider() + self.platform.get_table_utils() ) ) @@ -397,8 +396,7 @@ def analyze( start_date, end_date, enable_hitter_analysis, - self.platform.get_table_utils(), - self.platform.get_constants_provider() + self.platform.get_table_utils() ) self.platform.end_log_capture(log_token) self.platform.log_operation( @@ -411,65 +409,6 @@ def analyze( ) raise e - def validate( - self, - start_date: str | None = None, - end_date: str | None = None - ) -> None: - """ - Validate the GroupBy object. 
- - Args: - start_date: Start date for validation (format: YYYYMMDD) - end_date: End date for validation (format: YYYYMMDD) - """ - start_date = start_date or self.default_start_date - end_date = end_date or self.default_end_date - - self.platform.log_operation( - f"Validating GroupBy {self.obj.metaData.name} from " - f"{start_date} to {end_date}" - ) - - # Prepare GroupBy for validation - group_by_to_validate = copy.deepcopy(self.obj) - - # Update sources with correct dates - group_by_to_validate = self._update_source_dates_for_group_by( - group_by_to_validate, start_date, end_date - ) - - # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture( - f"Validate GroupBy: {self.obj.metaData.name}" - ) - - try: - # Convert to Java GroupBy - java_group_by: JavaObject = self.group_by_to_java(group_by_to_validate) - # Validate GroupBy - errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateGroupBy( - java_group_by, - start_date, - end_date, - self.platform.get_table_utils(), - self.platform.get_constants_provider() - ) - - self.platform.end_log_capture(log_token) - self.platform.handle_validation_errors( - errors_list, f"GroupBy {self.obj.metaData.name}" - ) - self.platform.log_operation( - f"Validation for GroupBy {self.obj.metaData.name} has completed" - ) - except Exception as e: - self.platform.log_operation( - f"Validation failed for GroupBy {self.obj.metaData.name}: {str(e)}" - ) - self.platform.end_log_capture(log_token) - raise e - class JoinExecutable(PySparkExecutable[Join], ABC): """Interface for executing Join objects""" @@ -574,8 +513,7 @@ def run( self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( None if not sample_num_of_rows else str(sample_num_of_rows) ), - self.platform.get_table_utils(), - self.platform.get_constants_provider() + self.platform.get_table_utils() ) result_df = DataFrame(result_df_scala, self.spark) @@ -642,8 +580,7 @@ def analyze( start_date, end_date, enable_hitter_analysis, - self.platform.get_table_utils(), - self.platform.get_constants_provider() + self.platform.get_table_utils() ) self.platform.end_log_capture(log_token) self.platform.log_operation( @@ -657,69 +594,6 @@ def analyze( ) raise e - def validate( - self, - start_date: str | None = None, - end_date: str | None = None - ) -> None: - """ - Validate the Join object. 
- - Args: - start_date: Start date for validation (format: YYYYMMDD) - end_date: End date for validation (format: YYYYMMDD) - """ - start_date: str = start_date or self.default_start_date - end_date: str = end_date or self.default_end_date - - self.platform.log_operation( - f"Validating Join {self.obj.metaData.name} from " - f"{start_date} to {end_date}" - ) - - # Prepare Join for validation - join_to_validate: Join = copy.deepcopy(self.obj) - join_to_validate.left = self._update_source_dates( - join_to_validate.left, start_date, end_date - ) - - # Update join parts sources - join_to_validate.joinParts = self._update_source_dates_for_join_parts( - join_to_validate.joinParts, start_date, end_date - ) - - # Start log capture just before executing JVM calls - log_token = self.platform.start_log_capture( - f"Validate Join: {self.obj.metaData.name}" - ) - - try: - # Convert to Java Join - java_join: JavaObject = self.join_to_java(join_to_validate) - # Validate Join - errors_list: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.validateJoin( - java_join, - start_date, - end_date, - self.platform.get_table_utils(), - self.platform.get_constants_provider() - ) - - self.platform.end_log_capture(log_token) - # Handle validation errors - self.platform.handle_validation_errors( - errors_list, f"Join {self.obj.metaData.name}" - ) - self.platform.log_operation( - f"Validation for Join {self.obj.metaData.name} has completed" - ) - except Exception as e: - self.platform.end_log_capture(log_token) - self.platform.log_operation( - f"Validation failed for Join {self.obj.metaData.name}: {str(e)}" - ) - raise e - class PlatformInterface(ABC): """ @@ -741,16 +615,6 @@ def __init__(self, spark: SparkSession) -> None: self.java_spark_session = spark._jsparkSession self.register_udfs() - @abstractmethod - def get_constants_provider(self) -> JavaObject: - """ - Get the platform-specific constants provider. - - Returns: - A JavaObject representing the constants provider - """ - pass - @abstractmethod def get_table_utils(self) -> JavaObject: """ @@ -836,22 +700,6 @@ def drop_table_if_exists(self, table_name: str) -> None: """ _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") - def handle_validation_errors(self, errors: JavaObject, object_name: str) -> None: - """ - Handle validation errors. - - Args: - errors: Platform-specific validation errors - object_name: Name of the object being validated - """ - if errors.length() > 0: - self.log_operation( - message=f"Validation failed for {object_name} " + - "with the following errors:" - ) - self.log_operation(message=str(errors)) - else: - self.log_operation(message=f"Validation passed for {object_name}.") def set_metadata( self, diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala deleted file mode 100644 index 8ef8f49b52..0000000000 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksConstantsNameProvider.scala +++ /dev/null @@ -1,19 +0,0 @@ -package ai.chronon.spark.databricks - -import ai.chronon.api.Extensions.{WindowOps, WindowUtils} -import ai.chronon.api.{ConstantNameProvider, PartitionSpec} - -/** - * DatabricksConstantsNameProvider provides JVM constants used in our Databricks integration. - * If you need any specific functionality pertaining to your Databricks JVM execution, - * you can implement it here. 
- */ -class DatabricksConstantsNameProvider extends ConstantNameProvider with Serializable { - override def TimeColumn: String = "_internal_time_column" - - override def DatePartitionColumn: String = "day" - - override def HourPartitionColumn: String = "hr" - - override def Partition: PartitionSpec = PartitionSpec(format = "yyyyMMdd", spanMillis = WindowUtils.Day.millis) -} From d7fd828170a0490127c19a7accf1691bc01e47d2 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:28:49 -0700 Subject: [PATCH 23/31] Adjust table utils logic --- .../ai/chronon/spark/databricks/DatabricksTableUtils.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index 4d095d8961..f1d78d6be6 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -1,7 +1,7 @@ package ai.chronon.spark.databricks import ai.chronon.api.Constants -import ai.chronon.spark.BaseTableUtils +import ai.chronon.spark.TableUtils.TableUtils import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Column, SparkSession} @@ -13,4 +13,4 @@ import org.apache.spark.sql.{Column, SparkSession} * * @param sparkSession The Spark session used for table operations. */ -case class DatabricksTableUtils(override val sparkSession: SparkSession) extends BaseTableUtils +case class DatabricksTableUtils(sparkSession: SparkSession) extends TableUtils(sparkSession) From 695c6ef58ac8069fcfc7a9cc43a43144a3a8da41 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:35:31 -0700 Subject: [PATCH 24/31] Remove unused imports from databricks table utils --- .../ai/chronon/spark/databricks/DatabricksTableUtils.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index f1d78d6be6..fe276ed3de 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -1,10 +1,6 @@ package ai.chronon.spark.databricks -import ai.chronon.api.Constants -import ai.chronon.spark.TableUtils.TableUtils -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{Column, SparkSession} - +import ai.chronon.spark.TableUtils /** * DatabricksTableUtils is the table utils class used in our Databricks integration. 
From 01d05f7665f392f00b8a96fdc33b6718ebf789ef Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:53:21 -0700 Subject: [PATCH 25/31] Fix table utils naming in pyspark utils --- spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index 4b99237cda..3005ab0ce1 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -92,7 +92,7 @@ object PySparkUtils { * @param tableUtils TableUtils this will be used to perform ops against our data sources * @return DataFrame */ - def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: BaseTableUtils): DataFrame = { + def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: TableUtils): DataFrame = { logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") GroupBy.computeBackfill( groupByConf, @@ -118,7 +118,7 @@ object PySparkUtils { stepDays: Option[Int], skipFirstHole: Boolean, sampleNumOfRows: Option[Int], - tableUtils: BaseTableUtils + tableUtils: TableUtils ): DataFrame = { logger.info(s"Executing Join ${joinConf.metaData.name}") val join = new Join( @@ -141,7 +141,7 @@ object PySparkUtils { * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present * @param tableUtils TableUtils this will be used to perform ops against our data sources */ - def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils): Unit = { + def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: TableUtils): Unit = { logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeGroupBy(groupByConf, enableHitter = enableHitterAnalysis) @@ -159,7 +159,7 @@ object PySparkUtils { * @param tableUtils TableUtils this will be used to perform ops against our data sources * @return DataFrame */ - def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: BaseTableUtils): Unit = { + def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: TableUtils): Unit = { logger.info(s"Analyzing Join: ${joinConf.metaData.name}") val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeJoin(joinConf, enableHitter = enableHitterAnalysis) From bb988308bca11e71f3ac32c3afb76133ca3e20f2 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 12:58:02 -0700 Subject: [PATCH 26/31] Fix linting in python + scala files --- api/py/ai/chronon/pyspark/executables.py | 1 - .../ai/chronon/spark/databricks/DatabricksTableUtils.scala | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index bc1b7cc3de..43de2d43b7 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -700,7 +700,6 @@ def 
drop_table_if_exists(self, table_name: str) -> None: """ _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") - def set_metadata( self, obj: GroupBy | Join | StagingQuery, diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala index fe276ed3de..1713aec499 100644 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala @@ -2,6 +2,8 @@ package ai.chronon.spark.databricks import ai.chronon.spark.TableUtils +import org.apache.spark.sql.SparkSession + /** * DatabricksTableUtils is the table utils class used in our Databricks integration. * If you need any specific functionality pertaining to reads/writes for your Databricks setup, From 34e001ecf81f7ddc65365fd3d48815919b849fe9 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:14:44 -0700 Subject: [PATCH 27/31] Had to remove Databricks table utils as we don't have table utils setup as a trait in the OSS repo. Can bring this back in a future PR --- api/py/ai/chronon/pyspark/databricks.py | 30 ------------------- api/py/ai/chronon/pyspark/executables.py | 21 ++++++------- .../databricks/DatabricksTableUtils.scala | 14 --------- 3 files changed, 11 insertions(+), 54 deletions(-) delete mode 100644 spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala diff --git a/api/py/ai/chronon/pyspark/databricks.py b/api/py/ai/chronon/pyspark/databricks.py index 4d06830365..7193edef12 100644 --- a/api/py/ai/chronon/pyspark/databricks.py +++ b/api/py/ai/chronon/pyspark/databricks.py @@ -35,38 +35,8 @@ def __init__(self, spark: SparkSession): """ super().__init__(spark) self.dbutils: DBUtils = DBUtils(self.spark) - self.constants_provider: JavaObject = self.get_constants_provider() - self.table_utils: JavaObject = self.get_table_utils() self.register_udfs() - @override - def get_constants_provider(self) -> JavaObject: - """ - Get the Databricks constants provider. - - Returns: - A JavaObject representing the constants provider - """ - constants_provider: JavaObject = ( - self.jvm.ai.chronon.spark.databricks.DatabricksConstantsNameProvider() - ) - self.jvm.ai.chronon.api.Constants.initConstantNameProvider( - constants_provider - ) - return constants_provider - - @override - def get_table_utils(self) -> JavaObject: - """ - Get the Databricks table utilities. - - Returns: - A JavaObject representing the table utilities - """ - return self.jvm.ai.chronon.spark.databricks.DatabricksTableUtils( - self.java_spark_session - ) - @override def register_udfs(self) -> None: """Register UDFs for Databricks.""" diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index 43de2d43b7..dec1942260 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -615,16 +615,6 @@ def __init__(self, spark: SparkSession) -> None: self.java_spark_session = spark._jsparkSession self.register_udfs() - @abstractmethod - def get_table_utils(self) -> JavaObject: - """ - Get the platform-specific table utilities. 
- - Returns: - A JavaObject representing the table utilities - """ - pass - @abstractmethod def get_executable_join_cls(self) -> type[JoinExecutable]: """ @@ -658,6 +648,17 @@ def end_log_capture(self, capture_token: Any) -> None: """ pass + def get_table_utils(self) -> JavaObject: + """ + Get the table utils class that will be used for read/write operations + on the JVM side. Can be overridden by subclasses to provide + platform-specific implementations. + + Returns: + A JavaObject representing the table utilities + """ + return self.jvm.ai.chronon.spark.TableUtils(self.java_spark_session) + def register_udfs(self) -> None: """ Register UDFs for the self.platform. diff --git a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala b/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala deleted file mode 100644 index 1713aec499..0000000000 --- a/spark/src/main/scala/ai/chronon/spark/databricks/DatabricksTableUtils.scala +++ /dev/null @@ -1,14 +0,0 @@ -package ai.chronon.spark.databricks - -import ai.chronon.spark.TableUtils - -import org.apache.spark.sql.SparkSession - -/** - * DatabricksTableUtils is the table utils class used in our Databricks integration. - * If you need any specific functionality pertaining to reads/writes for your Databricks setup, - * you can implement it here. - * - * @param sparkSession The Spark session used for table operations. - */ -case class DatabricksTableUtils(sparkSession: SparkSession) extends TableUtils(sparkSession) From b9d13da36ebc8de7290752b4fa59076fa91cf578 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:16:11 -0700 Subject: [PATCH 28/31] Remove unused imports --- api/py/ai/chronon/pyspark/databricks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/py/ai/chronon/pyspark/databricks.py b/api/py/ai/chronon/pyspark/databricks.py index 7193edef12..afd443ba30 100644 --- a/api/py/ai/chronon/pyspark/databricks.py +++ b/api/py/ai/chronon/pyspark/databricks.py @@ -3,7 +3,6 @@ import os from typing import cast -from py4j.java_gateway import JavaObject from pyspark.dbutils import DBUtils from pyspark.sql import SparkSession from typing_extensions import override From 605ca758d83a7cbfdbb6e8ab3eb6212083819599 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:21:24 -0700 Subject: [PATCH 29/31] Trying to fix linting issue for pyspark utils --- spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index 3005ab0ce1..4ca53326f1 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -148,7 +148,6 @@ object PySparkUtils { logger.info(s"Finished analyzing GroupBy: ${groupByConf.metaData.name}") } - /** * Helper function to analyze a Join * From 36855256a4d968426a7cb1a130a8f3cf3cfc94b9 Mon Sep 17 00:00:00 2001 From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:26:42 -0700 Subject: [PATCH 30/31] Trying to fix linting issue for pyspark utils --- .../scala/ai/chronon/spark/PySparkUtils.scala | 169 ++++++++++-------- 1 file changed, 90 insertions(+), 79 deletions(-) diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala 
b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index 4ca53326f1..25ca48167e 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -11,88 +11,92 @@ object PySparkUtils { @transient lazy val logger = LoggerFactory.getLogger(getClass) /** - * Pyspark has a tough time creating a FiveMinuteResolution via jvm.ai.chronon.aggregator.windowing.FiveMinuteResolution so we provide this helper method - * - * @return FiveMinuteResolution - */ + * Pyspark has a tough time creating a FiveMinuteResolution via jvm.ai.chronon.aggregator.windowing.FiveMinuteResolution so we provide this helper method + * + * @return FiveMinuteResolution + */ def getFiveMinuteResolution: Resolution = FiveMinuteResolution /** - * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around - * - * @param timeRange a time range - * @return Empty time range optional - */ - def getTimeRangeOptional(timeRange: TimeRange): Option[TimeRange] = if (timeRange == null) Option.empty[TimeRange] else Option(timeRange) + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * + * @param timeRange a time range + * @return Empty time range optional + */ + def getTimeRangeOptional(timeRange: TimeRange): Option[TimeRange] = + if (timeRange == null) Option.empty[TimeRange] else Option(timeRange) /** - * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around - * - * @param str a string - * @return String optional - */ + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * + * @param str a string + * @return String optional + */ def getStringOptional(str: String): Option[String] = if (str == null) Option.empty[String] else Option(str) /** - * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around - * Furthermore, ints can't be null in Scala so we need to pass the value in as a str - * - * @param strInt a string - * @return Int optional - */ + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * Furthermore, ints can't be null in Scala so we need to pass the value in as a str + * + * @param strInt a string + * @return Int optional + */ def getIntOptional(strInt: String): Option[Int] = if (strInt == null) Option.empty[Int] else Option(strInt.toInt) /** - * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr - * - * @param groupByJson a JSON string representing a group by - * @return Chronon Scala API GroupBy object - */ + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param groupByJson a JSON string representing a group by + * @return Chronon Scala API GroupBy object + */ def parseGroupBy(groupByJson: String): api.GroupBy = { ThriftJsonCodec.fromJsonStr[api.GroupBy](groupByJson, check = true, classOf[api.GroupBy]) } /** - * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr - * - * @param joinJson a JSON string representing a join - * @return Chronon Scala API Join object - */ + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param joinJson a JSON string representing a join + * @return 
Chronon Scala API Join object + */ def parseJoin(joinJson: String): api.Join = { ThriftJsonCodec.fromJsonStr[api.Join](joinJson, check = true, classOf[api.Join]) } /** - * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr - * - * @param sourceJson a JSON string representing a source. - * @return Chronon Scala API Source object - */ + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param sourceJson a JSON string representing a source. + * @return Chronon Scala API Source object + */ def parseSource(sourceJson: String): api.Source = { ThriftJsonCodec.fromJsonStr[api.Source](sourceJson, check = true, classOf[api.Source]) } /** - * Helper function to get Temporal or Snapshot Accuracy - * - * @param getTemporal boolean value that will decide if we return temporal or snapshot accuracy . - * @return api.Accuracy - */ + * Helper function to get Temporal or Snapshot Accuracy + * + * @param getTemporal boolean value that will decide if we return temporal or snapshot accuracy . + * @return api.Accuracy + */ def getAccuracy(getTemporal: Boolean): api.Accuracy = { if (getTemporal) api.Accuracy.TEMPORAL else api.Accuracy.SNAPSHOT } /** - * Helper function to allow a user to execute a Group By. - * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param endDate str this represents the last date we will perform the aggregation for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @return DataFrame - */ - def runGroupBy(groupByConf: api.GroupBy, endDate: String, stepDays: Option[Int], tableUtils: TableUtils): DataFrame = { + * Helper function to allow a user to execute a Group By. + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param endDate str this represents the last date we will perform the aggregation for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ + def runGroupBy(groupByConf: api.GroupBy, + endDate: String, + stepDays: Option[Int], + tableUtils: TableUtils): DataFrame = { logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") GroupBy.computeBackfill( groupByConf, @@ -105,21 +109,20 @@ object PySparkUtils { } /** - * Helper function to allow a user to execute a Join. - * - * @param joinConf api.Join Chronon scala Join API object - * @param endDate str this represents the last date we will perform the Join for - * @param stepDays int this will determine how we chunk filling the missing partitions - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @return DataFrame - */ + * Helper function to allow a user to execute a Join. 
+ * + * @param joinConf api.Join Chronon scala Join API object + * @param endDate str this represents the last date we will perform the Join for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ def runJoin(joinConf: api.Join, endDate: String, stepDays: Option[Int], skipFirstHole: Boolean, sampleNumOfRows: Option[Int], - tableUtils: TableUtils - ): DataFrame = { + tableUtils: TableUtils): DataFrame = { logger.info(s"Executing Join ${joinConf.metaData.name}") val join = new Join( joinConf, @@ -133,15 +136,19 @@ object PySparkUtils { } /** - * Helper function to analyze a GroupBy - * - * @param groupByConf api.GroupBy Chronon scala GroupBy API object - * @param startDate start date for the group by - * @param endDate end date for the group by - * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present - * @param tableUtils TableUtils this will be used to perform ops against our data sources - */ - def analyzeGroupBy(groupByConf: api.GroupBy, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: TableUtils): Unit = { + * Helper function to analyze a GroupBy + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + */ + def analyzeGroupBy(groupByConf: api.GroupBy, + startDate: String, + endDate: String, + enableHitterAnalysis: Boolean, + tableUtils: TableUtils): Unit = { logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeGroupBy(groupByConf, enableHitter = enableHitterAnalysis) @@ -149,16 +156,20 @@ object PySparkUtils { } /** - * Helper function to analyze a Join - * - * @param joinConf api.Join Chronon scala Join API object - * @param startDate start date for the join - * @param endDate end date for the join - * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present - * @param tableUtils TableUtils this will be used to perform ops against our data sources - * @return DataFrame - */ - def analyzeJoin(joinConf: api.Join, startDate: String, endDate: String, enableHitterAnalysis: Boolean, tableUtils: TableUtils): Unit = { + * Helper function to analyze a Join + * + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end date for the join + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ + def analyzeJoin(joinConf: api.Join, + startDate: String, + endDate: String, + enableHitterAnalysis: Boolean, + tableUtils: TableUtils): Unit = { logger.info(s"Analyzing Join: ${joinConf.metaData.name}") val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) analyzer.analyzeJoin(joinConf, enableHitter = enableHitterAnalysis) From e34a33fae47207be155f05c225d40b0f1e153513 Mon Sep 17 00:00:00 2001 From: camweston-stripe 
From e34a33fae47207be155f05c225d40b0f1e153513 Mon Sep 17 00:00:00 2001
From: camweston-stripe <116691078+camweston-stripe@users.noreply.github.com>
Date: Mon, 28 Apr 2025 13:58:39 -0700
Subject: [PATCH 31/31] Update readme

---
 api/py/ai/chronon/pyspark/README.md | 25 ++-----------------------
 1 file changed, 2 insertions(+), 23 deletions(-)

diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md
index 804d3772e8..7a0fad2cf6 100644
--- a/api/py/ai/chronon/pyspark/README.md
+++ b/api/py/ai/chronon/pyspark/README.md
@@ -13,7 +13,7 @@
 The Chronon PySpark Interface provides a clean, object-oriented framework for executing Chronon feature definitions directly within a PySpark environment, like Databricks Notebooks. This interface streamlines the developer experience by removing the need to switch between multiple tools, allowing rapid prototyping and iteration of Chronon feature engineering workflows.
 
 This library enables users to:
-- Run, analyze, and validate GroupBy and Join operations in a type-safe manner
+- Run and analyze GroupBy and Join operations in a type-safe manner
 - Execute feature computations within notebook environments like Databricks
 - Implement platform-specific behavior while preserving a consistent interface
 - Access JVM-based functionality directly from Python code
@@ -92,7 +92,7 @@ Two specialized interfaces extend the base executable for different Chronon types:
 - **GroupByExecutable**: Interface for executing GroupBy objects
 - **JoinExecutable**: Interface for executing Join objects
 
-These interfaces define type-specific behaviors for running, analyzing, and validating features.
+These interfaces define type-specific behaviors for running and analyzing features.
 
 ### Platform Interface
 
@@ -139,7 +139,6 @@ Concrete implementations for specific notebook environments:
 ├───────────────────┤    ├───────────────────┤
 │ + run()           │    │ + run()           │
 │ + analyze()       │    │ + analyze()       │
-│ + validate()      │    │ + validate()      │
 └────────┬──────────┘    └────────┬──────────┘
          │                        │
          │                        │
@@ -157,15 +156,12 @@ Concrete implementations for specific notebook environments:
 ├─────────────────────────────┤
 │ - spark: SparkSession       │
 ├─────────────────────────────┤
-│ + get_constants_provider()  │
-│ + get_table_utils()         │
 │ + register_udfs()           │
 │ + get_executable_join_cls() │
 │ + start_log_capture()       │
 │ + end_log_capture()         │
 │ + log_operation()           │
 │ + drop_table_if_exists()    │
-│ + handle_validation_errors()│
 └───────────┬─────────────────┘
             │
             │
@@ -173,11 +169,7 @@
 │ DatabricksPlatform         │
 ├────────────────────────────┤
 │ - dbutils: DBUtils         │
-│ - constants_provider       │
-│ - table_utils              │
 ├────────────────────────────┤
-│ + get_constants_provider() │
-│ + get_table_utils()        │
 │ + register_udfs()          │
 │ + get_executable_join_cls()│
 │ + start_log_capture()      │
@@ -224,16 +216,6 @@ class JupyterPlatform(PlatformInterface):
         super().__init__(spark)
         # Initialize Jupyter-specific components
-    @override
-    def get_constants_provider(self) -> JavaObject:
-        # Return Jupyter-specific constants provider
-        pass
-
-    @override
-    def get_table_utils(self) -> JavaObject:
-        # Return Jupyter-specific table utilities
-        pass
-
     @override
     def register_udfs(self) -> None:
         # Register any necessary UDFs for Jupyter
         pass
@@ -293,10 +275,7 @@ class JupyterJoin(JoinExecutable):
 
 When implementing a platform interface, pay special attention to these methods:
 
-- **get_constants_provider()**: Return a platform-specific implementation of constants
-- **get_table_utils()**: Return platform-specific table utilities
 - **start_log_capture()** and **end_log_capture()**: Implement platform-specific log capturing
-- **handle_validation_errors()**: Implement platform-specific error handling
 
 ## Setup and Dependencies