From 9980734c72514bf6e6b0546c848613fef8d33c52 Mon Sep 17 00:00:00 2001
From: Krish Narukulla
Date: Sun, 25 May 2025 16:25:26 -0700
Subject: [PATCH 1/5] Notebook support for spark on k8s

---
 api/py/BUILD.bazel | 11 +-
 api/py/ai/chronon/pyspark/README.md | 312 +++++++
 api/py/ai/chronon/pyspark/__init__.py | 0
 api/py/ai/chronon/pyspark/constants.py | 19 +
 api/py/ai/chronon/pyspark/executables.py | 764 ++++++++++++++++++
 api/py/ai/chronon/pyspark/notebooks.py | 164 ++++
 api/py/ai/chronon/utils.py | 30 +
 api/py/requirements/base.in | 1 +
 api/py/requirements/base.txt | 20 +-
 requirements.txt | 5 +-
 requirements_lock.txt | 15 +-
 .../scala/ai/chronon/spark/PySparkUtils.scala | 179 ++++
 12 files changed, 1501 insertions(+), 19 deletions(-)
 create mode 100644 api/py/ai/chronon/pyspark/README.md
 create mode 100644 api/py/ai/chronon/pyspark/__init__.py
 create mode 100644 api/py/ai/chronon/pyspark/constants.py
 create mode 100644 api/py/ai/chronon/pyspark/executables.py
 create mode 100644 api/py/ai/chronon/pyspark/notebooks.py
 create mode 100644 spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala

diff --git a/api/py/BUILD.bazel b/api/py/BUILD.bazel
index 9db4056c29..d5d2a95b79 100644
--- a/api/py/BUILD.bazel
+++ b/api/py/BUILD.bazel
@@ -12,6 +12,7 @@ py_library(
     deps = [
         "//api/thrift:api-models-py",
         requirement("thrift"),
+        requirement("pyspark"),
     ],
 )
 
@@ -40,21 +41,23 @@ pytest_suite(
         ["test/**/*.py"],
         exclude = ["test/sample/**/*"],
     ),
-    data = glob(["test/sample/**/*",
-                 "test/lineage/**/*.sql"]),
+    data = glob([
+        "test/sample/**/*",
+        "test/lineage/**/*.sql",
+    ]),
     env = {
         "CHRONON_ROOT": "api/py/",
     },
     imports = [
         ".",
-        "test/sample",
         "chronon/api/thrift",
+        "test/sample",
     ],
     deps = [
         "//api/py:api_py",
         "//api/thrift:api-models-py",
         requirement("thrift"),
         requirement("click"),
-        requirement("sqlglot")
+        requirement("sqlglot"),
     ],
 )
diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md
new file mode 100644
index 0000000000..0e2987efa3
--- /dev/null
+++ b/api/py/ai/chronon/pyspark/README.md
@@ -0,0 +1,312 @@
+# Chronon Python Interface for PySpark Environments
+
+## Table of Contents
+1. [Introduction](#introduction)
+2. [Architecture Overview](#architecture-overview)
+3. [Core Components](#core-components)
+4. [Flow of Execution](#flow-of-execution)
+5. [Extending the Framework](#extending-the-framework)
+6. [Setup and Dependencies](#setup-and-dependencies)
+
+## Introduction
+
+The Chronon PySpark Interface provides a clean, object-oriented framework for executing Chronon feature definitions directly within a PySpark environment, such as Notebooks. This interface streamlines the developer experience by removing the need to switch between multiple tools, allowing rapid prototyping and iteration of Chronon feature engineering workflows.
+ +This library enables users to: +- Run and Analyze GroupBy and Join operations in a type-safe manner +- Execute feature computations within notebook environments like Notebooks +- Implement platform-specific behavior while preserving a consistent interface +- Access JVM-based functionality directly from Python code + +## Architecture Overview + +### The Python-JVM Bridge + +At the core of this implementation is the interaction between Python and the Java Virtual Machine (JVM): + +``` +Python Environment | JVM Environment + | + ┌─────────────────────┐ | ┌─────────────────────┐ + │ Python Thrift Obj │ | │ Scala Thrift Obj │ + │ (GroupBy, Join) │─────┐ | ┌────▶│ (GroupBy, Join) │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ │ | │ │ + ▼ │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ thrift_simple_json()│ │ | │ │ PySparkUtils.parse │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ │ | │ │ + ▼ │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ JSON String │─────┼──────┼──────┼────▶│ Java Objects │ + └─────────────────────┘ │ | │ └─────────────────────┘ + │ | │ │ + │ | │ ▼ + ┌─────────────────────┐ │ | │ ┌─────────────────────┐ + │ PySparkExecutable │ │ | │ │ PySparkUtils.run │ + │ .run() │─────┼──────┼──────┼────▶│ GroupBy/Join │ + └─────────────────────┘ │ | │ └─────────────────────┘ + ▲ │ | │ │ + │ │ | │ ▼ + ┌─────────────────────┐ │ Py4J Socket │ ┌─────────────────────┐ + │ Python DataFrame │◀────┴──────┼──────┴─────│ JVM DataFrame │ + └─────────────────────┘ | └─────────────────────┘ + | + | + | + | + | + +``` + +- **Py4J**: Enables Python code to dynamically access Java objects, methods, and fields across the JVM boundary +- **PySpark**: Uses Py4J to communicate with the Spark JVM, translating Python calls into Spark's Java/Scala APIs +- **Thrift Objects**: Chronon features defined as Python thrift objects are converted to Java thrift objects for execution + +This design ensures that Python users can access the full power of Chronon's JVM-based computation engine all from a centralized Python environment. + +## Core Components + +The framework is built around several core abstractions: + +### PySparkExecutable + +An abstract base class that provides common functionality for executing Chronon features via PySpark: + +```python +class PySparkExecutable(Generic[T], ABC): + """ + Abstract base class defining common functionality for executing features via PySpark. + """ +``` + +- Handles initialization with object and SparkSession +- Provides utilities for updating dates in sources and queries +- Manages the execution of underlying join sources + +### Specialized Executables + +Two specialized interfaces extend the base executable for different Chronon types: + +- **GroupByExecutable**: Interface for executing GroupBy objects +- **JoinExecutable**: Interface for executing Join objects + +These interfaces define type-specific behaviors for running and analyzing features. + +### Platform Interface + +A key abstraction that enables platform-specific behavior: + +```python +class PlatformInterface(ABC): + """ + Interface for platform-specific operations. + """ +``` + +This interface defines operations that vary by platform (Notebooks, Jupyter, etc.) and must be implemented by platform-specific classes. 
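
For orientation, here is that interface in condensed form. This sketch is distilled from the `PlatformInterface` class added in `executables.py` by this change, with method bodies and docstrings elided: only the join-executable hook and the two log-capture hooks are abstract, while the remaining methods ship with workable defaults that subclasses may override.

```python
from abc import ABC, abstractmethod
from typing import Any

from pyspark.sql import SparkSession


class PlatformInterface(ABC):
    def __init__(self, spark: SparkSession) -> None:
        self.spark = spark
        self.jvm = spark._jvm
        self.java_spark_session = spark._jsparkSession
        self.register_udfs()

    # Must be implemented per platform:
    @abstractmethod
    def get_executable_join_cls(self): ...  # returns the platform's JoinExecutable subclass

    @abstractmethod
    def start_log_capture(self, job_name: str) -> Any: ...

    @abstractmethod
    def end_log_capture(self, capture_token: Any) -> None: ...

    # Provided with defaults, override only if needed:
    def register_udfs(self) -> None: ...            # no-op unless overridden
    def get_table_utils(self): ...                  # ai.chronon.spark.TableUtils via Py4J
    def log_operation(self, message: str) -> None: ...
    def drop_table_if_exists(self, table_name: str) -> None: ...
    def set_metadata(self, obj, mod_prefix, name_prefix=None, output_namespace=None): ...
```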
+ +### Platform-Specific Implementations + +Concrete implementations for specific notebook environments: + +- **NotebooksPlatform**: Implements platform-specific operations for Notebooks +- **NotebooksGroupBy**: Executes GroupBy objects in Notebooks +- **NotebooksJoin**: Executes Join objects in Notebooks + +``` +┌─────────────────────────┐ +│ PySparkExecutable │ +│ (Generic[T], ABC) │ +├─────────────────────────┤ +│ - obj: T │ +│ - spark: SparkSession │ +│ - jvm: JVMView │ +├─────────────────────────┤ +│ + get_platform() │ +│ # _update_query_dates() │ +│ # _update_source_dates()│ +│ # print_with_timestamp()│ +└───────────────┬─────────┘ + │ + ┌───────────┴────────────┐ + │ │ +┌───▼───────────────┐ ┌───▼───────────────┐ +│ GroupByExecutable│ │ JoinExecutable │ +├───────────────────┤ ├───────────────────┤ +│ │ │ │ +├───────────────────┤ ├───────────────────┤ +│ + run() │ │ + run() │ +│ + analyze() │ │ + analyze() │ +└────────┬──────────┘ └────────┬──────────┘ + │ │ + │ │ +┌────────▼──────────┐ ┌────────▼──────────┐ +│ NotebooksGroupBy │ │ NotebooksJoin │ +├───────────────────┤ ├───────────────────┤ +│ │ │ │ +├───────────────────┤ ├───────────────────┤ +│ + get_platform() │ │ + get_platform() │ +└───────────────────┘ └───────────────────┘ + +┌─────────────────────────────┐ +│ PlatformInterface │ +│ (ABC) │ +├─────────────────────────────┤ +│ - spark: SparkSession │ +├─────────────────────────────┤ +│ + register_udfs() │ +│ + get_executable_join_cls() │ +│ + start_log_capture() │ +│ + end_log_capture() │ +│ + log_operation() │ +│ + drop_table_if_exists() │ +└───────────┬─────────────────┘ + │ + │ +┌───────────▼────────────────┐ +│ NotebooksPlatform │ +├────────────────────────────┤ +│ - dbutils: DBUtils │ +├────────────────────────────┤ +│ + register_udfs() │ +│ + get_executable_join_cls()│ +│ + start_log_capture() │ +│ + end_log_capture() │ +│ + get_notebooks_user() │ +└────────────────────────────┘ +``` + +## Flow of Execution + +When a user calls a method like `NotebooksGroupBy(group_by, py_spark_session).run()`, the following sequence occurs: + +1. **Object Preparation**: + - The Python thrift object (GroupBy, Join) is copied and updated with appropriate dates (This interface is meant to be used to run prototypes over smaller date ranges and not full backfills) + - Underlying join sources are executed if needed + +2. **JVM Conversion**: + - The Python thrift object is converted to JSON representation + - The JSON is parsed into a Java thrift object on the JVM side via Py4J + +3. **Execution**: + - The JVM executes the computation using Spark + - Results are captured in a Spark DataFrame on the JVM side + +4. **Result Return**: + - The JVM DataFrame is wrapped in a Python DataFrame object + - The Python DataFrame is returned to the user + +5. **Log Handling**: + - Throughout the process, JVM logs are captured + - Logs are printed in the notebook upon completion or errors + +## Extending the Framework + +### Implementing a New Platform Interface + +To add support for a new notebook environment (e.g., Jupyter), follow these steps: + +1. 
**Create a new platform implementation**:

```python
class JupyterPlatform(PlatformInterface):
    def __init__(self, spark: SparkSession):
        super().__init__(spark)
        # Initialize Jupyter-specific components

    @override
    def register_udfs(self) -> None:
        # Register any necessary UDFs for Jupyter
        # Recall that UDFs are registered to the shared spark-sql engine,
        # so you can register Python and/or Scala UDFs and use them on both Spark sessions
        pass

    @override
    def get_executable_join_cls(self) -> type[JoinExecutable]:
        # Return the Jupyter-specific join executable class
        return JupyterJoin

    @override
    def start_log_capture(self, job_name: str) -> Any:
        # Start capturing logs in Jupyter
        pass

    @override
    def end_log_capture(self, capture_token: Any) -> None:
        # End log capturing and print the logs in Jupyter
        pass
```

2. **Create concrete executable implementations**:

```python
class JupyterGroupBy(GroupByExecutable):
    def __init__(self, group_by: GroupBy, spark_session: SparkSession):
        super().__init__(group_by, spark_session)
        # You can pass Jupyter-specific parameters into set_metadata
        # to customize things like:
        # - What namespace is written to
        # - Table name prefixing (in the Notebooks implementation we prefix the table name with the notebook username)
        # - Root dir for where your existing feature defs are if you want to import features that were defined in an IDE
        self.obj: GroupBy = self.platform.set_metadata(obj=self.obj)

    @override
    def get_platform(self) -> PlatformInterface:
        return JupyterPlatform(self.spark)

class JupyterJoin(JoinExecutable):
    def __init__(self, join: Join, spark_session: SparkSession):
        super().__init__(join, spark_session)
        # You can pass Jupyter-specific parameters into set_metadata
        # to customize things like:
        # - What namespace is written to
        # - Table name prefixing (in the Notebooks implementation we prefix the table name with the notebook username)
        # - Root dir for where your existing feature defs are if you want to import features that were defined in an IDE
        self.obj: Join = self.platform.set_metadata(obj=self.obj)

    @override
    def get_platform(self) -> PlatformInterface:
        return JupyterPlatform(self.spark)
```

### Key Methods to Override

When implementing a platform interface, pay special attention to these methods:

- **start_log_capture()** and **end_log_capture()**: Implement platform-specific log capturing

## Setup and Dependencies

### Requirements

1. **Spark Dependencies**: The Chronon Java/Scala JARs must be included in your Spark cluster

2. **Python Dependencies**:
   - pyspark (tested on both 3.1 and 3.3)

3. **Log File**: You will need to make sure that your Chronon JVM logs are written to a single file. This is generally platform specific.

### Example Setup

Here's a minimal example of setting up and using the Chronon Python interface in a Notebooks notebook. It assumes that you have already included the necessary jars in your cluster dependencies.

```python
# Import the required modules
from pyspark.sql import SparkSession
from ai.chronon.pyspark.notebooks import NotebooksGroupBy, NotebooksJoin
from ai.chronon.api.ttypes import GroupBy, Join
from ai.chronon.group_by import Aggregation, Operation, Window, TimeUnit

# Define your GroupBy or Join object
my_group_by = GroupBy(...)
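# (The fields above are left elided on purpose: a real definition typically wires
# together your sources, key columns and Aggregations built from the Operation /
# Window / TimeUnit helpers imported above; the exact fields depend on your own
# feature definition.)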
+ +# Create an executable +executable = NotebooksGroupBy(my_group_by, spark) + +# Run the executable +result_df = executable.run(start_date='20250101', end_date='20250107') +``` + +--- \ No newline at end of file diff --git a/api/py/ai/chronon/pyspark/__init__.py b/api/py/ai/chronon/pyspark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py new file mode 100644 index 0000000000..1bdf4fa510 --- /dev/null +++ b/api/py/ai/chronon/pyspark/constants.py @@ -0,0 +1,19 @@ +from __future__ import annotations +from typing import Optional + +# -------------------------------------------------------------------------- +# Company Specific Constants +# -------------------------------------------------------------------------- + +PARTITION_COLUMN_FORMAT: str = '%Y%m%d' + +# -------------------------------------------------------------------------- +# Platform Specific Constants +# -------------------------------------------------------------------------- + +# -------------------------------------------------------------------------- +# Notebooks Constants +# -------------------------------------------------------------------------- +NOTEBOOKS_OUTPUT_NAMESPACE: Optional[str] = None +NOTEBOOKS_JVM_LOG_FILE: str = "/notebooks/chronon_logfile.log" +NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py new file mode 100644 index 0000000000..dec1942260 --- /dev/null +++ b/api/py/ai/chronon/pyspark/executables.py @@ -0,0 +1,764 @@ +from __future__ import annotations + +import copy +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import Any, Generic, TypeVar, cast + +from py4j.java_gateway import JVMView, JavaObject +from pyspark.sql import DataFrame, SparkSession + +from ai.chronon.api.ttypes import ( + GroupBy, Join, JoinPart, JoinSource, Query, Source, StagingQuery +) +from ai.chronon.pyspark.constants import PARTITION_COLUMN_FORMAT +from ai.chronon.repo.serializer import thrift_simple_json +from ai.chronon.utils import ( + __set_name as set_name, get_max_window_for_gb_in_days, output_table_name +) + +# Define type variable for our executables +T = TypeVar('T', GroupBy, Join) + + +class PySparkExecutable(Generic[T], ABC): + """ + Abstract base class defining common functionality for executing features via PySpark. + This class provides shared initialization and utility methods but does not + define abstract execution methods, which are left to specialized subclasses. + """ + + def __init__(self, obj: T, spark_session: SparkSession): + """ + Initialize the executable with an object and SparkSession. + + Args: + obj: The object (GroupBy, Join, StagingQuery) to execute + spark_session: The SparkSession to use for execution + """ + self.obj: T = obj + self.spark: SparkSession = spark_session + self.jvm: JVMView = self.spark._jvm + self.platform: PlatformInterface = self.get_platform() + self.java_spark_session: JavaObject = self.spark._jsparkSession + self.default_start_date: str = ( + datetime.now() - timedelta(days=8) + ).strftime(PARTITION_COLUMN_FORMAT) + self.default_end_date: str = ( + datetime.now() - timedelta(days=1) + ).strftime(PARTITION_COLUMN_FORMAT) + + @abstractmethod + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface for platform-specific operations. 
+ + Returns: + A PlatformInterface instance + """ + pass + + def _update_query_dates( + self, query: Query, start_date: str, end_date: str + ) -> Query: + """ + Update start and end dates of a query. + + Args: + query: The query to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated query + """ + query_copy = copy.deepcopy(query) + query_copy.startPartition = start_date + query_copy.endPartition = end_date + return query_copy + + def _update_source_dates( + self, source: Source, start_date: str, end_date: str + ) -> Source: + """ + Update start and end dates of a source. + + Args: + source: The source to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated source + """ + source_copy = copy.deepcopy(source) + if source_copy.events and source_copy.events.query: + source_copy.events.query = self._update_query_dates( + cast(Query, source_copy.events.query), start_date, end_date) + elif source_copy.entities and source_copy.entities.query: + source_copy.entities.query = self._update_query_dates( + cast(Query, source_copy.entities.query), start_date, end_date) + return source_copy + + def _execute_underlying_join_sources( + self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int + ) -> None: + """ + Execute underlying join sources. + + Args: + group_bys: List of GroupBy objects + start_date: Start date for execution + end_date: End date for execution + step_days: Number of days to process in each step + """ + joins_to_execute: list[Join] = [] + join_sources_to_execute_start_dates: dict[str, str] = {} + + for group_by in group_bys: + group_by_join_sources: list[JoinSource] = [ + s.joinSource for s in cast(list[Source], group_by.sources) + if s.joinSource and not s.joinSource.outputTableNameOverride + ] + + if not group_by_join_sources: + continue + + # Recall that records generated by the inner join are input events + # for the outer join. Therefore in order to correctly aggregate the + # outer join, your inner join needs to be run from + # start_date - max_window_for_gb_in_days + max_window_for_gb_in_days: int = get_max_window_for_gb_in_days( + group_by) + + shifted_start_date = ( + datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - + timedelta(days=max_window_for_gb_in_days) + ).strftime(PARTITION_COLUMN_FORMAT) + + for js in group_by_join_sources: + js_name: str | None = js.join.metaData.name + + if js_name is None: + raise ValueError( + f"Join source {js} does not have a name. " + "Was set_metadata called?" 
+ ) + + if js_name not in join_sources_to_execute_start_dates: + join_sources_to_execute_start_dates[js_name] = shifted_start_date + joins_to_execute.append(js.join) + else: + existing_start_date: str = join_sources_to_execute_start_dates[ + js_name + ] + join_sources_to_execute_start_dates[js_name] = min( + shifted_start_date, existing_start_date + ) + + if not joins_to_execute: + return + + self.platform.log_operation( + f"Executing {len(joins_to_execute)} Join Sources" + ) + + for join in joins_to_execute: + j_start_date: str = join_sources_to_execute_start_dates[ + join.metaData.name + ] + + executable_join_cls: type[JoinExecutable] = ( + self.platform.get_executable_join_cls() + ) + + executable_join = executable_join_cls(join, self.spark) + + self.platform.log_operation( + f"Executing Join Source {join.metaData.name} " + f"from {j_start_date} to {end_date}" + ) + _ = executable_join.run( + start_date=j_start_date, end_date=end_date, step_days=step_days + ) + + output_table_name_for_js: str = output_table_name( + join, full_name=True + ) + self.platform.log_operation( + f"Join Source {join.metaData.name} will be read from " + f"{output_table_name_for_js}" + ) + + self.platform.log_operation("Finished executing Join Sources") + + def print_with_timestamp(self, message: str) -> None: + """Utility to print a message with timestamp.""" + current_utc_time = datetime.utcnow() + time_str = current_utc_time.strftime('[%Y-%m-%d %H:%M:%S UTC]') + print(f'{time_str} {message}') + + def group_by_to_java(self, group_by: GroupBy) -> JavaObject: + """ + Convert GroupBy object to Java representation. + + Args: + group_by: The GroupBy object to convert + + Returns: + Java representation of the GroupBy + """ + json_representation: str = thrift_simple_json(group_by) + java_group_by: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseGroupBy( + json_representation + ) + return java_group_by + + def join_to_java(self, join: Join) -> JavaObject: + """ + Convert Join object to Java representation. + + Args: + join: The Join object to convert + + Returns: + Java representation of the Join + """ + json_representation: str = thrift_simple_json(join) + java_join: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.parseJoin( + json_representation + ) + return java_join + + +class GroupByExecutable(PySparkExecutable[GroupBy], ABC): + """Interface for executing GroupBy objects""" + + def _update_source_dates_for_group_by( + self, group_by: GroupBy, start_date: str, end_date: str + ) -> GroupBy: + """ + Update start and end dates of sources in GroupBy. + + Args: + group_by: The GroupBy object to update + start_date: The new start date + end_date: The new end date + + Returns: + The updated GroupBy object + """ + if not group_by.sources: + return group_by + + for i, source in enumerate(group_by.sources): + group_by.sources[i] = self._update_source_dates( + source, start_date, end_date + ) + return group_by + + def run( + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_execution_of_underlying_join: bool = False + ) -> DataFrame: + """ + Execute the GroupBy object. 
+ + Args: + start_date: Start date for the execution (format: YYYYMMDD) + end_date: End date for the execution (format: YYYYMMDD) + step_days: Number of days to process in each step + skip_execution_of_underlying_join: Whether to skip execution of + underlying joins + + Returns: + DataFrame with the execution results + """ + start_date: str = start_date or self.default_start_date + end_date: str = end_date or self.default_end_date + + self.platform.log_operation( + f"Executing GroupBy {self.obj.metaData.name} from " + f"{start_date} to {end_date} with step_days {step_days}" + ) + self.platform.log_operation( + f"Skip Execution of Underlying Join Sources: " + f"{skip_execution_of_underlying_join}" + ) + + if not skip_execution_of_underlying_join: + self._execute_underlying_join_sources( + group_bys=[self.obj], + start_date=start_date, + end_date=end_date, + step_days=step_days + ) + + # Prepare GroupBy for execution + group_by_to_execute: GroupBy = copy.deepcopy(self.obj) + group_by_to_execute.backfillStartDate = start_date + + # Update sources with correct dates + group_by_to_execute = self._update_source_dates_for_group_by( + group_by_to_execute, start_date, end_date + ) + + # Get output table name + group_by_output_table: str = output_table_name( + group_by_to_execute, full_name=True + ) + + # GroupBy backfills don't store the semantic hash as a property in + # the table the same way joins do. + # Therefore we drop the backfill table to avoid data quality issues. + self.platform.drop_table_if_exists(table_name=group_by_output_table) + + # Find starting point for log capture just before executing JVM calls + log_token = self.platform.start_log_capture( + f"Run GroupBy: {self.obj.metaData.name}" + ) + + try: + # Convert to Java GroupBy + java_group_by = self.group_by_to_java(group_by_to_execute) + # Execute GroupBy + result_df_scala: JavaObject = ( + self.jvm.ai.chronon.spark.PySparkUtils.runGroupBy( + java_group_by, + end_date, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + str(step_days)), + self.platform.get_table_utils() + ) + ) + + result_df = DataFrame(result_df_scala, self.spark) + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"GroupBy {self.obj.metaData.name} executed successfully " + f"and was written to {group_by_output_table}" + ) + return result_df + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Execution failed for GroupBy {self.obj.metaData.name}: {str(e)}" + ) + raise e + + def analyze( + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False + ) -> None: + """ + Analyze the GroupBy object. 
+ + Args: + start_date: Start date for analysis (format: YYYYMMDD) + end_date: End date for analysis (format: YYYYMMDD) + enable_hitter_analysis: Whether to enable hitter analysis + """ + start_date = start_date or self.default_start_date + end_date = end_date or self.default_end_date + + self.platform.log_operation( + f"Analyzing GroupBy {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) + self.platform.log_operation( + f"Enable Hitter Analysis: {enable_hitter_analysis}" + ) + + # Prepare GroupBy for analysis + group_by_to_analyze: GroupBy = copy.deepcopy(self.obj) + + # Update sources with correct dates + group_by_to_analyze = self._update_source_dates_for_group_by( + group_by_to_analyze, start_date, end_date + ) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture( + f"Analyze GroupBy: {self.obj.metaData.name}" + ) + + try: + # Convert to Java GroupBy + java_group_by = self.group_by_to_java(group_by_to_analyze) + # Analyze GroupBy + self.jvm.ai.chronon.spark.PySparkUtils.analyzeGroupBy( + java_group_by, + start_date, + end_date, + enable_hitter_analysis, + self.platform.get_table_utils() + ) + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"GroupBy {self.obj.metaData.name} analyzed successfully" + ) + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Analysis failed for GroupBy {self.obj.metaData.name}: {str(e)}" + ) + raise e + + +class JoinExecutable(PySparkExecutable[Join], ABC): + """Interface for executing Join objects""" + + def _update_source_dates_for_join_parts( + self, join_parts: list[JoinPart], start_date: str, end_date: str + ) -> list[JoinPart]: + """ + Update start and end dates of sources in JoinParts. + + Args: + join_parts: List of JoinPart objects + start_date: The new start date + end_date: The new end date + + Returns: + The updated list of JoinPart objects + """ + if not join_parts: + return [] + + for jp in join_parts: + for i, source in enumerate(jp.groupBy.sources): + jp.groupBy.sources[i] = self._update_source_dates( + source, start_date, end_date + ) + return join_parts + + def run( + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_first_hole: bool = False, + sample_num_of_rows: int | None = None, + skip_execution_of_underlying_join: bool = False + ) -> DataFrame: + """ + Execute the Join object with Join-specific parameters. 
+ + Args: + start_date: Start date for the execution (format: YYYYMMDD) + end_date: End date for the execution (format: YYYYMMDD) + step_days: Number of days to process in each step + skip_first_hole: Whether to skip the first hole in the join + sample_num_of_rows: Number of rows to sample (None for all) + skip_execution_of_underlying_join: Whether to skip execution of + underlying joins + + Returns: + DataFrame with the execution results + """ + start_date = start_date or self.default_start_date + end_date = end_date or self.default_end_date + + self.platform.log_operation( + f"Executing Join {self.obj.metaData.name} from " + f"{start_date} to {end_date} with step_days {step_days}" + ) + self.platform.log_operation(f"Skip First Hole: {skip_first_hole}") + self.platform.log_operation(f"Sample Number of Rows: {sample_num_of_rows}") + self.platform.log_operation( + f"Skip Execution of Underlying Join: {skip_execution_of_underlying_join}" + ) + + # Prepare Join for execution + join_to_execute = copy.deepcopy(self.obj) + join_to_execute.left = self._update_source_dates( + join_to_execute.left, start_date, end_date + ) + + if not skip_execution_of_underlying_join and self.obj.joinParts: + self._execute_underlying_join_sources( + [jp.groupBy for jp in join_to_execute.joinParts], + start_date, end_date, step_days + ) + + # Update join parts sources + join_to_execute.joinParts = self._update_source_dates_for_join_parts( + join_to_execute.joinParts, start_date, end_date + ) + + # Get output table name + join_output_table = output_table_name(join_to_execute, full_name=True) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture( + f"Run Join: {self.obj.metaData.name}" + ) + + try: + # Convert to Java Join + java_join: JavaObject = self.join_to_java(join_to_execute) + # Execute Join + result_df_scala: JavaObject = self.jvm.ai.chronon.spark.PySparkUtils.runJoin( + java_join, + end_date, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + str(step_days) + ), + skip_first_hole, + self.jvm.ai.chronon.spark.PySparkUtils.getIntOptional( + None if not sample_num_of_rows else str(sample_num_of_rows) + ), + self.platform.get_table_utils() + ) + + result_df = DataFrame(result_df_scala, self.spark) + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Join {self.obj.metaData.name} executed successfully and " + f"was written to {join_output_table}" + ) + return result_df + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Execution failed for Join {self.obj.metaData.name}: {str(e)}" + ) + raise e + + def analyze( + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False + ) -> None: + """ + Analyze the Join object. 
+ + Args: + start_date: Start date for analysis (format: YYYYMMDD) + end_date: End date for analysis (format: YYYYMMDD) + enable_hitter_analysis: Whether to enable hitter analysis + """ + start_date: str = start_date or self.default_start_date + end_date: str = end_date or self.default_end_date + + self.platform.log_operation( + f"Analyzing Join {self.obj.metaData.name} from " + f"{start_date} to {end_date}" + ) + self.platform.log_operation( + f"Enable Hitter Analysis: {enable_hitter_analysis}" + ) + + # Prepare Join for analysis + join_to_analyze: Join = copy.deepcopy(self.obj) + join_to_analyze.left = self._update_source_dates( + join_to_analyze.left, start_date, end_date + ) + + # Update join parts sources + join_to_analyze.joinParts = self._update_source_dates_for_join_parts( + join_to_analyze.joinParts, start_date, end_date + ) + + # Start log capture just before executing JVM calls + log_token = self.platform.start_log_capture( + f"Analyze Join: {self.obj.metaData.name}" + ) + + try: + # Convert to Java Join + java_join: JavaObject = self.join_to_java(join_to_analyze) + # Analyze Join + self.jvm.ai.chronon.spark.PySparkUtils.analyzeJoin( + java_join, + start_date, + end_date, + enable_hitter_analysis, + self.platform.get_table_utils() + ) + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Join {self.obj.metaData.name} analyzed successfully" + ) + + except Exception as e: + self.platform.end_log_capture(log_token) + self.platform.log_operation( + f"Analysis failed for Join {self.obj.metaData.name}: {str(e)}" + ) + raise e + + +class PlatformInterface(ABC): + """ + Interface for platform-specific operations. + + This class defines operations that vary by platform (Databricks, Jupyter, etc.) + and should be implemented by platform-specific classes. + """ + + def __init__(self, spark: SparkSession) -> None: + """ + Initialize with a SparkSession. + + Args: + spark: The SparkSession to use + """ + self.spark = spark + self.jvm = spark._jvm + self.java_spark_session = spark._jsparkSession + self.register_udfs() + + @abstractmethod + def get_executable_join_cls(self) -> type[JoinExecutable]: + """ + Get the class for executing joins. + + Returns: + The class for executing joins + """ + pass + + @abstractmethod + def start_log_capture(self, job_name: str) -> Any: + """ + Start capturing logs for a job. + + Args: + job_name: The name of the job for log headers + + Returns: + A token representing the capture state (platform-specific) + """ + pass + + @abstractmethod + def end_log_capture(self, capture_token: Any) -> None: + """ + End log capturing and print the logs. + + Args: + capture_token: The token returned from start_log_capture + """ + pass + + def get_table_utils(self) -> JavaObject: + """ + Get the table utils class that will be used for read/write operations + on the JVM side. Can be overridden by subclasses to provide + platform-specific implementations. + + Returns: + A JavaObject representing the table utilities + """ + return self.jvm.ai.chronon.spark.TableUtils(self.java_spark_session) + + def register_udfs(self) -> None: + """ + Register UDFs for the self.platform. + + This method is intentionally left empty but not abstract, as some + platforms may not need to register UDFs. + + Subclasses can override this method to provide platform-specific UDF + registration. + + Pro tip: Both the JVM Spark Session and Python Spark Session use the + same spark-sql engine. You can register Python UDFS and use them in + the JVM, as well as vice-versa. 
+ At Stripe we currently only use Scala UDFs, so we include a JAR of our + UDFs in the cluster and register them via: + + self.jvm.com.stripe.shepherd.udfs.ShepherdUdfRegistry.registerAllUdfs( + self.java_spark_session + ) + """ + pass + + def log_operation(self, message: str) -> None: + """ + Log an operation. + + Args: + message: The message to log + """ + current_utc_time = datetime.utcnow() + time_str = current_utc_time.strftime('[%Y-%m-%d %H:%M:%S UTC]') + print(f'{time_str} {message}') + + def drop_table_if_exists(self, table_name: str) -> None: + """ + Drop a table if it exists. + + Args: + table_name: The name of the table to drop + """ + _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") + + def set_metadata( + self, + obj: GroupBy | Join | StagingQuery, + mod_prefix: str, + name_prefix: str | None = None, + output_namespace: str | None = None + ) -> T: + """ + Set the metadata for the object. + + Args: + obj: The object to set metadata on + mod_prefix: The root directory of where your features exist + name_prefix: Optional prefix to add to the name + output_namespace: Optional namespace to override the set output namespace + + Returns: + The updated object + """ + obj_type: type[GroupBy] | type[Join] = type(obj) + + # Handle object naming + if not obj.metaData.name: + try: + set_name(obj=obj, cls=obj_type, mod_prefix=mod_prefix) + except AttributeError: + raise AttributeError( + "Please provide a name when defining " + + "group_bys/joins/staging_queries adhoc." + ) + # We do this to avoid adding the name prefix multiple times + elif obj.metaData.name and name_prefix: + obj.metaData.name = obj.metaData.name.replace(f"{name_prefix}_", "") + + # Handle nested objects for GroupBy + if obj_type == GroupBy: + for s in obj.sources: + if s.joinSource and not s.joinSource.join.metaData.name: + set_name(obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix) + # Handle nested objects for Join + elif obj_type == Join: + for jp in obj.joinParts: + for s in jp.groupBy.sources: + if s.joinSource and not s.joinSource.join.metaData.name: + set_name( + obj=s.joinSource.join, cls=Join, mod_prefix=mod_prefix + ) + + # Set output namespace + if output_namespace: + obj.metaData.outputNamespace = output_namespace + + if obj_type == Join: + for jp in obj.joinParts: + jp.groupBy.metaData.outputNamespace = output_namespace + + # Add user prefix to name + if name_prefix: + obj.metaData.name = f"{name_prefix}_{obj.metaData.name}" + + return obj diff --git a/api/py/ai/chronon/pyspark/notebooks.py b/api/py/ai/chronon/pyspark/notebooks.py new file mode 100644 index 0000000000..b9111fedc2 --- /dev/null +++ b/api/py/ai/chronon/pyspark/notebooks.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +try: + from pyspark.dbutils import DBUtils +except ImportError: + class DBUtils: + def __init__(self, spark): + print("Mock DBUtils: Notebooks-specific features will not work.") + def fs(self): + return self + def ls(self, path): + print(f"Mock ls called on path: {path}") + return [] + def mkdirs(self, path): + print(f"Mock mkdirs called on path: {path}") + def put(self, path, content, overwrite): + print(f"Mock put called on path: {path} with content and overwrite={overwrite}") + + +from pyspark.dbutils import DBUtils +from pyspark.sql import SparkSession +from typing_extensions import override + +from ai.chronon.api.ttypes import GroupBy, Join +from ai.chronon.pyspark.constants import ( + # NOTEBOOKS_JVM_LOG_FILE, --> Commented by ONUR + NOTEBOOKS_OUTPUT_NAMESPACE, + 
NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, +) +from ai.chronon.pyspark.executables import ( + GroupByExecutable, + JoinExecutable, + PlatformInterface, +) + + +class NotebooksPlatform(PlatformInterface): + """ + Notebooks-specific implementation of the platform interface. + """ + + def __init__(self, spark: SparkSession): + """ + Initialize Notebooks-specific components. + + Args: + spark: The SparkSession to use + """ + super().__init__(spark) + self.dbutils: DBUtils = DBUtils(self.spark) + self.register_udfs() + + @override + def register_udfs(self) -> None: + """Register UDFs for Notebooks.""" + pass + + @override + def get_executable_join_cls(self) -> type[JoinExecutable]: + """Get the Notebooks-specific join executable class.""" + return NotebooksJoin + + @override + def start_log_capture(self, job_name: str) -> tuple[int, str]: + """ + Start capturing logs in Notebooks. + + Args: + job_name: The name of the job for log headers + + Returns: + A tuple of (start_position, job_name) + """ + # return (os.path.getsize(NOTEBOOKS_JVM_LOG_FILE), job_name) --> Commented by odincol + return (0, job_name) # --> Refactored by odincol + + @override + def end_log_capture(self, capture_token: tuple[int, str]) -> None: + """ + End log capture and print logs in Notebooks. + + Args: + capture_token: The token returned by start_log_capture + """ + start_position, job_name = capture_token + + print("\n\n", "*" * 10, f" BEGIN LOGS FOR {job_name} ", "*" * 10) + # with open(NOTEBOOKS_JVM_LOG_FILE, "r") as file_handler: --> Commented by odincol + # _ = file_handler.seek(start_position) + # print(file_handler.read()) + print("*" * 10, f" END LOGS FOR {job_name} ", "*" * 10, "\n\n") + + def get_notebooks_user(self) -> str: + """ + Get the current Notebooks user. + + Returns: + The username of the current Notebooks user + """ + #user_email = self.dbutils.notebook.entry_point.getDbutils().notebook( --> Commented by odincol + #).getContext().userName().get() + return "" #user_email.split('@')[0].lower() # --> Refactored by odincol + + +class NotebooksGroupBy(GroupByExecutable): + """Class for executing GroupBy objects in Notebooks.""" + + def __init__(self, group_by: GroupBy, spark_session: SparkSession): + """ + Initialize a GroupBy executor for Notebooks. + + Args: + group_by: The GroupBy object to execute + spark_session: The SparkSession to use + """ + super().__init__(group_by, spark_session) + + self.obj: GroupBy = self.platform.set_metadata( + obj=self.obj, + mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, + name_prefix=cast(NotebooksPlatform, self.platform).get_notebooks_user(), + output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE + ) + + @override + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface. + + Returns: + The Notebooks platform interface + """ + return NotebooksPlatform(self.spark) + + +class NotebooksJoin(JoinExecutable): + """Class for executing Join objects in Notebooks.""" + + def __init__(self, join: Join, spark_session: SparkSession): + """ + Initialize a Join executor for Notebooks. + + Args: + join: The Join object to execute + spark_session: The SparkSession to use + """ + super().__init__(join, spark_session) + + self.obj: Join = self.platform.set_metadata( + obj=self.obj, + mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, + name_prefix=cast(NotebooksPlatform, self.platform).get_notebooks_user(), + output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE + ) + + @override + def get_platform(self) -> PlatformInterface: + """ + Get the platform interface. 
+ + Returns: + The Notebooks platform interface + """ + return NotebooksPlatform(self.spark) diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index 870231a458..d00c59a67c 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -575,3 +575,33 @@ def get_config_path(join_name: str) -> str: assert "." in join_name, f"Invalid join name: {join_name}" team_name, config_name = join_name.split(".", 1) return f"production/joins/{team_name}/{config_name}" + +def get_max_window_for_gb_in_days(group_by: api.GroupBy) -> int: + result: int = 1 + if group_by.aggregations: + for agg in group_by.aggregations: + for window in agg.windows: + if window.timeUnit == api.TimeUnit.MINUTES: + result = int( + max( + result, + ceil(window.length / 60 * 24), + ) + ) + elif window.timeUnit == api.TimeUnit.HOURS: + result = int( + max( + result, + ceil(window.length / 24), + ) + ) + elif window.timeUnit == api.TimeUnit.DAYS: + result = int( + max(result, window.length) + ) + else: + raise ValueError( + f"Unsupported time unit {window.timeUnit}. " + + "Please add logic above to handle the newly introduced time unit." + ) + return result \ No newline at end of file diff --git a/api/py/requirements/base.in b/api/py/requirements/base.in index 74b7ad51e5..167cc3818b 100644 --- a/api/py/requirements/base.in +++ b/api/py/requirements/base.in @@ -1,3 +1,4 @@ click thrift<0.14 sqlglot +pyspark==3.3.1 \ No newline at end of file diff --git a/api/py/requirements/base.txt b/api/py/requirements/base.txt index e5570ef610..7c2a8d8e20 100644 --- a/api/py/requirements/base.txt +++ b/api/py/requirements/base.txt @@ -1,4 +1,4 @@ -# SHA1:4478faec22832482c18fa37db7c49640e4dc015b +# SHA1:b59a9d5665a7082dea5b0bc9f4d1bc2ec737d9e8 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -6,16 +6,14 @@ # pip-compile-multi # click==8.1.8 - # via -r requirements/base.in -importlib-metadata==6.7.0 - # via click + # via -r base.in +py4j==0.10.9.5 + # via pyspark +pyspark==3.3.1 + # via -r base.in six==1.17.0 # via thrift -sqlglot==26.8.0 - # via -r requirements/base.in +sqlglot==26.16.1 + # via -r base.in thrift==0.13.0 - # via -r requirements/base.in -typing-extensions==4.7.1 - # via importlib-metadata -zipp==3.15.0 - # via importlib-metadata + # via -r base.in diff --git a/requirements.txt b/requirements.txt index fe776ef797..19e2a4f021 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,8 @@ click thrift==0.13 pytest twine==6.1.0 -sqlglot==25.17.0 +sqlglot==26.16.1 +py4j==0.10.9.5 +pyspark==3.3.1 + # To update dependency run: bazel run //:pip.update \ No newline at end of file diff --git a/requirements_lock.txt b/requirements_lock.txt index e22d6eb043..ba7e3eb848 100644 --- a/requirements_lock.txt +++ b/requirements_lock.txt @@ -206,12 +206,21 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest +py4j==0.10.9.5 \ + --hash=sha256:276a4a3c5a2154df1860ef3303a927460e02e97b047dc0a47c1c3fb8cce34db6 \ + --hash=sha256:52d171a6a2b031d8a5d1de6efe451cf4f5baff1a2819aabc3741c8406539ba04 + # via + # -r requirements.txt + # pyspark pygments==2.19.1 \ --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c # via # readme-renderer # rich +pyspark==3.3.1 \ + --hash=sha256:e99fa7de92be406884bfd831c32b9306a3a99de44cfc39a2eefb6ed07445d5fa + # via -r 
requirements.txt pytest==8.3.4 \ --hash=sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6 \ --hash=sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761 @@ -243,9 +252,9 @@ six==1.17.0 \ --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 # via thrift -sqlglot==25.17.0 \ - --hash=sha256:8580475f4ee27032ad00b366b8a1967f3630f119a07ed6da92653adcba7ba731 \ - --hash=sha256:91f3741f815a5e1d1dd157a428268af3eda43632dad56790d5c547be1c0491d0 +sqlglot==26.16.1 \ + --hash=sha256:496cb742da55d491ae0c5b38d84e498362ad17a1eef1009d9b336b108a9ee636 \ + --hash=sha256:cced52b35bebb828722f2f4ae4d677d840470ef348f160945ae0ef3d4e457ef8 # via -r requirements.txt thrift==0.13.0 \ --hash=sha256:9af1c86bf73433afc6010ed376a6c6aca2b54099cc0d61895f640870a9ae7d89 diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala new file mode 100644 index 0000000000..25ca48167e --- /dev/null +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -0,0 +1,179 @@ +package ai.chronon.spark + +import ai.chronon.aggregator.windowing.{FiveMinuteResolution, Resolution} +import ai.chronon.api +import ai.chronon.api.Extensions.MetadataOps +import ai.chronon.api.ThriftJsonCodec +import org.apache.spark.sql.DataFrame +import org.slf4j.LoggerFactory + +object PySparkUtils { + @transient lazy val logger = LoggerFactory.getLogger(getClass) + + /** + * Pyspark has a tough time creating a FiveMinuteResolution via jvm.ai.chronon.aggregator.windowing.FiveMinuteResolution so we provide this helper method + * + * @return FiveMinuteResolution + */ + + def getFiveMinuteResolution: Resolution = FiveMinuteResolution + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * + * @param timeRange a time range + * @return Empty time range optional + */ + def getTimeRangeOptional(timeRange: TimeRange): Option[TimeRange] = + if (timeRange == null) Option.empty[TimeRange] else Option(timeRange) + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * + * @param str a string + * @return String optional + */ + def getStringOptional(str: String): Option[String] = if (str == null) Option.empty[String] else Option(str) + + /** + * Creating optionals is difficult to support in Pyspark, so we provide this method as a work around + * Furthermore, ints can't be null in Scala so we need to pass the value in as a str + * + * @param strInt a string + * @return Int optional + */ + def getIntOptional(strInt: String): Option[Int] = if (strInt == null) Option.empty[Int] else Option(strInt.toInt) + + /** + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param groupByJson a JSON string representing a group by + * @return Chronon Scala API GroupBy object + */ + def parseGroupBy(groupByJson: String): api.GroupBy = { + ThriftJsonCodec.fromJsonStr[api.GroupBy](groupByJson, check = true, classOf[api.GroupBy]) + } + + /** + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param joinJson a JSON string representing a join + * @return Chronon Scala API Join object + */ + def parseJoin(joinJson: String): api.Join = { + ThriftJsonCodec.fromJsonStr[api.Join](joinJson, check = true, 
classOf[api.Join]) + } + + /** + * Type parameters are difficult to support in Pyspark, so we provide these helper methods for ThriftJsonCodec.fromJsonStr + * + * @param sourceJson a JSON string representing a source. + * @return Chronon Scala API Source object + */ + def parseSource(sourceJson: String): api.Source = { + ThriftJsonCodec.fromJsonStr[api.Source](sourceJson, check = true, classOf[api.Source]) + } + + /** + * Helper function to get Temporal or Snapshot Accuracy + * + * @param getTemporal boolean value that will decide if we return temporal or snapshot accuracy . + * @return api.Accuracy + */ + def getAccuracy(getTemporal: Boolean): api.Accuracy = { + if (getTemporal) api.Accuracy.TEMPORAL else api.Accuracy.SNAPSHOT + } + + /** + * Helper function to allow a user to execute a Group By. + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param endDate str this represents the last date we will perform the aggregation for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ + def runGroupBy(groupByConf: api.GroupBy, + endDate: String, + stepDays: Option[Int], + tableUtils: TableUtils): DataFrame = { + logger.info(s"Executing GroupBy: ${groupByConf.metaData.name}") + GroupBy.computeBackfill( + groupByConf, + endDate, + tableUtils, + stepDays + ) + logger.info(s"Finished executing GroupBy: ${groupByConf.metaData.name}") + tableUtils.sql(s"SELECT * FROM ${groupByConf.metaData.outputTable}") + } + + /** + * Helper function to allow a user to execute a Join. + * + * @param joinConf api.Join Chronon scala Join API object + * @param endDate str this represents the last date we will perform the Join for + * @param stepDays int this will determine how we chunk filling the missing partitions + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ + def runJoin(joinConf: api.Join, + endDate: String, + stepDays: Option[Int], + skipFirstHole: Boolean, + sampleNumOfRows: Option[Int], + tableUtils: TableUtils): DataFrame = { + logger.info(s"Executing Join ${joinConf.metaData.name}") + val join = new Join( + joinConf, + endDate, + tableUtils, + skipFirstHole = skipFirstHole + ) + val resultDf = join.computeJoin(stepDays) + logger.info(s"Finished executing Join ${joinConf.metaData.name}") + resultDf + } + + /** + * Helper function to analyze a GroupBy + * + * @param groupByConf api.GroupBy Chronon scala GroupBy API object + * @param startDate start date for the group by + * @param endDate end date for the group by + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + */ + def analyzeGroupBy(groupByConf: api.GroupBy, + startDate: String, + endDate: String, + enableHitterAnalysis: Boolean, + tableUtils: TableUtils): Unit = { + logger.info(s"Analyzing GroupBy: ${groupByConf.metaData.name}") + val analyzer = new Analyzer(tableUtils, groupByConf, startDate, endDate, enableHitter = enableHitterAnalysis) + analyzer.analyzeGroupBy(groupByConf, enableHitter = enableHitterAnalysis) + logger.info(s"Finished analyzing GroupBy: ${groupByConf.metaData.name}") + } + + /** + * Helper function to analyze a Join + * + * @param joinConf api.Join Chronon scala Join API object + * @param startDate start date for the join + * @param endDate end 
date for the join + * @param enableHitterAnalysis if true we will perform an analysis of what hot keys may be present + * @param tableUtils TableUtils this will be used to perform ops against our data sources + * @return DataFrame + */ + def analyzeJoin(joinConf: api.Join, + startDate: String, + endDate: String, + enableHitterAnalysis: Boolean, + tableUtils: TableUtils): Unit = { + logger.info(s"Analyzing Join: ${joinConf.metaData.name}") + val analyzer = new Analyzer(tableUtils, joinConf, startDate, endDate, enableHitter = enableHitterAnalysis) + analyzer.analyzeJoin(joinConf, enableHitter = enableHitterAnalysis) + logger.info(s"Finished analyzing Join: ${joinConf.metaData.name}") + } + +} From ee5ca2ca19b1a7d51182a5c80d44a7776dca244e Mon Sep 17 00:00:00 2001 From: Krish Narukulla Date: Sun, 25 May 2025 16:30:06 -0700 Subject: [PATCH 2/5] spark 3.5.5 version for notebooks --- api/py/requirements/base.in | 2 +- api/py/requirements/base.txt | 2 +- requirements.txt | 4 ++-- requirements_lock.txt | 10 +++++----- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/api/py/requirements/base.in b/api/py/requirements/base.in index 167cc3818b..10bfb55b2d 100644 --- a/api/py/requirements/base.in +++ b/api/py/requirements/base.in @@ -1,4 +1,4 @@ click thrift<0.14 sqlglot -pyspark==3.3.1 \ No newline at end of file +pyspark==3.5.5 \ No newline at end of file diff --git a/api/py/requirements/base.txt b/api/py/requirements/base.txt index 7c2a8d8e20..0616487bd0 100644 --- a/api/py/requirements/base.txt +++ b/api/py/requirements/base.txt @@ -9,7 +9,7 @@ click==8.1.8 # via -r base.in py4j==0.10.9.5 # via pyspark -pyspark==3.3.1 +pyspark==3.5.5 # via -r base.in six==1.17.0 # via thrift diff --git a/requirements.txt b/requirements.txt index 19e2a4f021..e91d412a7e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,7 @@ thrift==0.13 pytest twine==6.1.0 sqlglot==26.16.1 -py4j==0.10.9.5 -pyspark==3.3.1 +py4j==0.10.9.7 +pyspark==3.5.5 # To update dependency run: bazel run //:pip.update \ No newline at end of file diff --git a/requirements_lock.txt b/requirements_lock.txt index ba7e3eb848..9e3be87915 100644 --- a/requirements_lock.txt +++ b/requirements_lock.txt @@ -206,9 +206,9 @@ pluggy==1.5.0 \ --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -py4j==0.10.9.5 \ - --hash=sha256:276a4a3c5a2154df1860ef3303a927460e02e97b047dc0a47c1c3fb8cce34db6 \ - --hash=sha256:52d171a6a2b031d8a5d1de6efe451cf4f5baff1a2819aabc3741c8406539ba04 +py4j==0.10.9.7 \ + --hash=sha256:0b6e5315bb3ada5cf62ac651d107bb2ebc02def3dee9d9548e3baac644ea8dbb \ + --hash=sha256:85defdfd2b2376eb3abf5ca6474b51ab7e0de341c75a02f46dc9b5976f5a5c1b # via # -r requirements.txt # pyspark @@ -218,8 +218,8 @@ pygments==2.19.1 \ # via # readme-renderer # rich -pyspark==3.3.1 \ - --hash=sha256:e99fa7de92be406884bfd831c32b9306a3a99de44cfc39a2eefb6ed07445d5fa +pyspark==3.5.5 \ + --hash=sha256:6effc9ce98edf231f4d683fd14f7270629bf8458c628d6a2620ded4bb34f3cb9 # via -r requirements.txt pytest==8.3.4 \ --hash=sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6 \ From 36de04809768e4fb2eb761dd54b5a5d27a449da0 Mon Sep 17 00:00:00 2001 From: Krish Narukulla Date: Tue, 27 May 2025 14:13:51 -0700 Subject: [PATCH 3/5] code review comments --- api/py/ai/chronon/pyspark/README.md | 19 +++---- api/py/ai/chronon/pyspark/constants.py | 6 +- .../{notebooks.py => jupyter_platform.py} | 57 ++++--------------- 3 files 
changed, 21 insertions(+), 61 deletions(-) rename api/py/ai/chronon/pyspark/{notebooks.py => jupyter_platform.py} (59%) diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md index 0e2987efa3..cc8a0e58b9 100644 --- a/api/py/ai/chronon/pyspark/README.md +++ b/api/py/ai/chronon/pyspark/README.md @@ -111,9 +111,9 @@ This interface defines operations that vary by platform (Notebooks, Jupyter, etc Concrete implementations for specific notebook environments: -- **NotebooksPlatform**: Implements platform-specific operations for Notebooks -- **NotebooksGroupBy**: Executes GroupBy objects in Notebooks -- **NotebooksJoin**: Executes Join objects in Notebooks +- **JupyterPlatform**: Implements platform-specific operations for Notebooks +- **JupyterGroupBy**: Executes GroupBy objects in Notebooks +- **JupyterJoin**: Executes Join objects in Notebooks ``` ┌─────────────────────────┐ @@ -143,7 +143,7 @@ Concrete implementations for specific notebook environments: │ │ │ │ ┌────────▼──────────┐ ┌────────▼──────────┐ -│ NotebooksGroupBy │ │ NotebooksJoin │ +│ JupyterGroupBy │ │ JupyterJoin │ ├───────────────────┤ ├───────────────────┤ │ │ │ │ ├───────────────────┤ ├───────────────────┤ @@ -166,21 +166,18 @@ Concrete implementations for specific notebook environments: │ │ ┌───────────▼────────────────┐ -│ NotebooksPlatform │ -├────────────────────────────┤ -│ - dbutils: DBUtils │ +│ JupyterPlatform │ ├────────────────────────────┤ │ + register_udfs() │ │ + get_executable_join_cls()│ │ + start_log_capture() │ │ + end_log_capture() │ -│ + get_notebooks_user() │ └────────────────────────────┘ ``` ## Flow of Execution -When a user calls a method like `NotebooksGroupBy(group_by, py_spark_session).run()`, the following sequence occurs: +When a user calls a method like `JupyterGroupBy(group_by, py_spark_session).run()`, the following sequence occurs: 1. **Object Preparation**: - The Python thrift object (GroupBy, Join) is copied and updated with appropriate dates (This interface is meant to be used to run prototypes over smaller date ranges and not full backfills) @@ -295,7 +292,7 @@ Here's a minimal example of setting up and using the Chronon Python interface in ```python # Import the required modules from pyspark.sql import SparkSession -from ai.chronon.pyspark.notebooks import NotebooksGroupBy, NotebooksJoin +from ai.chronon.pyspark.jupyter_platform import JupyterGroupBy, JupyterJoin from ai.chronon.api.ttypes import GroupBy, Join from ai.chronon.group_by import Aggregation, Operation, Window, TimeUnit @@ -303,7 +300,7 @@ from ai.chronon.group_by import Aggregation, Operation, Window, TimeUnit my_group_by = GroupBy(...) 
# Create an executable -executable = NotebooksGroupBy(my_group_by, spark) +executable = JupyterGroupBy(my_group_by, spark) # Run the executable result_df = executable.run(start_date='20250101', end_date='20250107') diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py index 1bdf4fa510..7acd8b1ede 100644 --- a/api/py/ai/chronon/pyspark/constants.py +++ b/api/py/ai/chronon/pyspark/constants.py @@ -14,6 +14,6 @@ # -------------------------------------------------------------------------- # Notebooks Constants # -------------------------------------------------------------------------- -NOTEBOOKS_OUTPUT_NAMESPACE: Optional[str] = None -NOTEBOOKS_JVM_LOG_FILE: str = "/notebooks/chronon_logfile.log" -NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" +JUPYTER_OUTPUT_NAMESPACE: Optional[str] = None +JUPYTER_JVM_LOG_FILE: str = "/jupyter/chronon_logfile.log" +JUPYTER_ROOT_DIR_FOR_IMPORTED_FEATURES: str = "src" diff --git a/api/py/ai/chronon/pyspark/notebooks.py b/api/py/ai/chronon/pyspark/jupyter_platform.py similarity index 59% rename from api/py/ai/chronon/pyspark/notebooks.py rename to api/py/ai/chronon/pyspark/jupyter_platform.py index b9111fedc2..e550513b12 100644 --- a/api/py/ai/chronon/pyspark/notebooks.py +++ b/api/py/ai/chronon/pyspark/jupyter_platform.py @@ -1,23 +1,7 @@ from __future__ import annotations -try: - from pyspark.dbutils import DBUtils -except ImportError: - class DBUtils: - def __init__(self, spark): - print("Mock DBUtils: Notebooks-specific features will not work.") - def fs(self): - return self - def ls(self, path): - print(f"Mock ls called on path: {path}") - return [] - def mkdirs(self, path): - print(f"Mock mkdirs called on path: {path}") - def put(self, path, content, overwrite): - print(f"Mock put called on path: {path} with content and overwrite={overwrite}") - - -from pyspark.dbutils import DBUtils +from typing import cast + from pyspark.sql import SparkSession from typing_extensions import override @@ -34,7 +18,7 @@ def put(self, path, content, overwrite): ) -class NotebooksPlatform(PlatformInterface): +class JupyterPlatform(PlatformInterface): """ Notebooks-specific implementation of the platform interface. 
""" @@ -47,7 +31,6 @@ def __init__(self, spark: SparkSession): spark: The SparkSession to use """ super().__init__(spark) - self.dbutils: DBUtils = DBUtils(self.spark) self.register_udfs() @override @@ -58,7 +41,7 @@ def register_udfs(self) -> None: @override def get_executable_join_cls(self) -> type[JoinExecutable]: """Get the Notebooks-specific join executable class.""" - return NotebooksJoin + return JupyterJoin @override def start_log_capture(self, job_name: str) -> tuple[int, str]: @@ -71,8 +54,7 @@ def start_log_capture(self, job_name: str) -> tuple[int, str]: Returns: A tuple of (start_position, job_name) """ - # return (os.path.getsize(NOTEBOOKS_JVM_LOG_FILE), job_name) --> Commented by odincol - return (0, job_name) # --> Refactored by odincol + pass @override def end_log_capture(self, capture_token: tuple[int, str]) -> None: @@ -82,27 +64,10 @@ def end_log_capture(self, capture_token: tuple[int, str]) -> None: Args: capture_token: The token returned by start_log_capture """ - start_position, job_name = capture_token - - print("\n\n", "*" * 10, f" BEGIN LOGS FOR {job_name} ", "*" * 10) - # with open(NOTEBOOKS_JVM_LOG_FILE, "r") as file_handler: --> Commented by odincol - # _ = file_handler.seek(start_position) - # print(file_handler.read()) - print("*" * 10, f" END LOGS FOR {job_name} ", "*" * 10, "\n\n") - - def get_notebooks_user(self) -> str: - """ - Get the current Notebooks user. - - Returns: - The username of the current Notebooks user - """ - #user_email = self.dbutils.notebook.entry_point.getDbutils().notebook( --> Commented by odincol - #).getContext().userName().get() - return "" #user_email.split('@')[0].lower() # --> Refactored by odincol + pass -class NotebooksGroupBy(GroupByExecutable): +class JupyterGroupBy(GroupByExecutable): """Class for executing GroupBy objects in Notebooks.""" def __init__(self, group_by: GroupBy, spark_session: SparkSession): @@ -118,7 +83,6 @@ def __init__(self, group_by: GroupBy, spark_session: SparkSession): self.obj: GroupBy = self.platform.set_metadata( obj=self.obj, mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, - name_prefix=cast(NotebooksPlatform, self.platform).get_notebooks_user(), output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE ) @@ -130,10 +94,10 @@ def get_platform(self) -> PlatformInterface: Returns: The Notebooks platform interface """ - return NotebooksPlatform(self.spark) + return JupyterPlatform(self.spark) -class NotebooksJoin(JoinExecutable): +class JupyterJoin(JoinExecutable): """Class for executing Join objects in Notebooks.""" def __init__(self, join: Join, spark_session: SparkSession): @@ -149,7 +113,6 @@ def __init__(self, join: Join, spark_session: SparkSession): self.obj: Join = self.platform.set_metadata( obj=self.obj, mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, - name_prefix=cast(NotebooksPlatform, self.platform).get_notebooks_user(), output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE ) @@ -161,4 +124,4 @@ def get_platform(self) -> PlatformInterface: Returns: The Notebooks platform interface """ - return NotebooksPlatform(self.spark) + return JupyterPlatform(self.spark) From c896829500744d852449271cf377c5d98cda7a36 Mon Sep 17 00:00:00 2001 From: Krish Narukulla Date: Thu, 29 May 2025 17:52:29 -0700 Subject: [PATCH 4/5] code review feedback --- api/py/ai/chronon/.DS_Store | Bin 0 -> 6148 bytes api/py/ai/chronon/pyspark/README.md | 34 +++++--- api/py/ai/chronon/pyspark/constants.py | 1 + api/py/ai/chronon/pyspark/executables.py | 74 +++++++++--------- api/py/ai/chronon/pyspark/jupyter_platform.py | 20 ++--- 5 
files changed, 69 insertions(+), 60 deletions(-) create mode 100644 api/py/ai/chronon/.DS_Store diff --git a/api/py/ai/chronon/.DS_Store b/api/py/ai/chronon/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b1b4bc04cd280c5951a495737cf3b82b54e2be2a GIT binary patch literal 6148 zcmeHK!AiqG5S?wSO({YSiXIod7Hn%N6fdFHA26Z^m735{gE3p0)*ebBSN$RX#P4xt zcPnZYJc-yDn0=GknS^->I~f2F-RYnK-~xbyN?54i@`cbo>53Grr-&%%9@!{NRESs4 z;bEaOrSI;}=gm${wi~TQ zP0l;bdQG+t+lxiT+1WohI_=*j_o;d`3<~_%v}{>i!V?;{mPrJ)}MrME+U9X^D8val11&`-zvOC1ivGsr75 zzzlq3fM!1wm-_$7_w&CF;)NMt2L2}lqSW(yU98Ra)|D!$*GkkIR1%8I4ZfwIp<6M= eQY+p=)q;LW2BK#%H;5h-{t?hL@WKrIDg&<*bx&;o literal 0 HcmV?d00001 diff --git a/api/py/ai/chronon/pyspark/README.md b/api/py/ai/chronon/pyspark/README.md index cc8a0e58b9..5d7c9e4e70 100644 --- a/api/py/ai/chronon/pyspark/README.md +++ b/api/py/ai/chronon/pyspark/README.md @@ -1,6 +1,7 @@ # Chronon Python Interface for PySpark Environments ## Table of Contents + 1. [Introduction](#introduction) 2. [Architecture Overview](#architecture-overview) 3. [Core Components](#core-components) @@ -10,9 +11,13 @@ ## Introduction -The Chronon PySpark Interface provides a clean, object-oriented framework for executing Chronon feature definitions directly within a PySpark environment, like Notebooks Notebooks. This interface streamlines the developer experience by removing the need to switch between multiple tools, allowing rapid prototyping and iteration of Chronon feature engineering workflows. +The Chronon PySpark Interface provides a clean, object-oriented framework for executing Chronon feature definitions +directly within a PySpark environment, like Notebooks Notebooks. This interface streamlines the developer experience by +removing the need to switch between multiple tools, allowing rapid prototyping and iteration of Chronon feature +engineering workflows. This library enables users to: + - Run and Analyze GroupBy and Join operations in a type-safe manner - Execute feature computations within notebook environments like Notebooks - Implement platform-specific behavior while preserving a consistent interface @@ -62,9 +67,11 @@ Python Environment | JVM Environment - **Py4J**: Enables Python code to dynamically access Java objects, methods, and fields across the JVM boundary - **PySpark**: Uses Py4J to communicate with the Spark JVM, translating Python calls into Spark's Java/Scala APIs -- **Thrift Objects**: Chronon features defined as Python thrift objects are converted to Java thrift objects for execution +- **Thrift Objects**: Chronon features defined as Python thrift objects are converted to Java thrift objects for + execution -This design ensures that Python users can access the full power of Chronon's JVM-based computation engine all from a centralized Python environment. +This design ensures that Python users can access the full power of Chronon's JVM-based computation engine all from a +centralized Python environment. ## Core Components @@ -105,7 +112,8 @@ class PlatformInterface(ABC): """ ``` -This interface defines operations that vary by platform (Notebooks, Jupyter, etc.) and must be implemented by platform-specific classes. +This interface defines operations that vary by platform (Notebooks, Jupyter, etc.) and must be implemented by +platform-specific classes. 
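Before looking at the concrete platform classes, the Python-to-JVM thrift conversion described above is compact enough to sketch. This is a minimal illustration, not the library's exact call path: `thrift_simple_json` is the serializer that `executables.py` imports from `ai.chronon.repo.serializer`, while `parseGroupBy` below is only a placeholder for whatever parsing entry point `ai.chronon.spark.PySparkUtils` actually exposes.

```python
# Sketch of the Python -> JVM thrift handoff described above.
# Assumes `spark` has the Chronon jars on the driver classpath;
# `parseGroupBy` is a hypothetical name for the PySparkUtils parser.
from pyspark.sql import SparkSession

from ai.chronon.api.ttypes import GroupBy
from ai.chronon.repo.serializer import thrift_simple_json


def group_by_to_jvm(spark: SparkSession, group_by: GroupBy):
    # 1. Serialize the Python thrift object to Thrift simple JSON.
    group_by_json = thrift_simple_json(group_by)
    # 2. Cross the Py4J boundary; the Scala side rebuilds the equivalent
    #    Java thrift object from the JSON string.
    return spark._jvm.ai.chronon.spark.PySparkUtils.parseGroupBy(group_by_json)
```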
### Platform-Specific Implementations @@ -180,7 +188,8 @@ Concrete implementations for specific notebook environments: When a user calls a method like `JupyterGroupBy(group_by, py_spark_session).run()`, the following sequence occurs: 1. **Object Preparation**: - - The Python thrift object (GroupBy, Join) is copied and updated with appropriate dates (This interface is meant to be used to run prototypes over smaller date ranges and not full backfills) + - The Python thrift object (GroupBy, Join) is copied and updated with appropriate dates (This interface is meant to + be used to run prototypes over smaller date ranges and not full backfills) - Underlying join sources are executed if needed 2. **JVM Conversion**: @@ -212,24 +221,24 @@ class JupyterPlatform(PlatformInterface): def __init__(self, spark: SparkSession): super().__init__(spark) # Initialize Jupyter-specific components - + @override def register_udfs(self) -> None: # Register any necessary UDFs for Jupyter # Recall that UDFs are registered to the shared spark-sql engine # So you can register python and or scala udfs and use them on both spark sessions pass - + @override def get_executable_join_cls(self) -> type[JoinExecutable]: # Return the Jupyter-specific join executable class return JupyterJoin - + @override def start_log_capture(self, job_name: str) -> Any: # Start capturing logs in Jupyter pass - + @override def end_log_capture(self, capture_token: Any) -> None: # End log capturing and print the logs in Jupyter @@ -253,6 +262,7 @@ class JupyterGroupBy(GroupByExecutable): def get_platform(self) -> PlatformInterface: return JupyterPlatform(self.spark) + class JupyterJoin(JoinExecutable): def __init__(self, join: Join, spark_session: SparkSession): super().__init__(join, spark_session) @@ -283,11 +293,13 @@ When implementing a platform interface, pay special attention to these methods: 2. **Python Dependencies**: - pyspark (tested on both 3.1 and 3.3) -3. **Log File**: You will need to make sure that your Chronon JVM logs are writting to single file. This is generally platform specific. +3. **Log File**: You will need to make sure that your Chronon JVM logs are writting to single file. This is generally + platform specific. ### Example Setup -Here's a minimal example of setting up and using the Chronon Python interface in a Notebooks notebook. It assumes that you have already included the necessary jars in your cluster dependencies. +Here's a minimal example of setting up and using the Chronon Python interface in a Notebooks notebook. It assumes that +you have already included the necessary jars in your cluster dependencies. 
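One prerequisite above deserves a concrete sketch before the end-to-end example below: log capture only works if the JVM's log4j output is appended to the single file named by `JUPYTER_JVM_LOG_FILE`. Under that assumption, a file-offset based implementation of the two hooks could look like the following; it is a sketch, not the shipped `JupyterPlatform` behaviour.

```python
# Possible file-offset implementation of the log-capture hooks, assuming
# the JVM log4j appender writes to the single JUPYTER_JVM_LOG_FILE path.
import os

from ai.chronon.pyspark.constants import JUPYTER_JVM_LOG_FILE


def start_log_capture(job_name: str) -> tuple[int, str]:
    # Everything appended to the log file after this offset belongs to
    # the job that is about to run.
    return os.path.getsize(JUPYTER_JVM_LOG_FILE), job_name


def end_log_capture(capture_token: tuple[int, str]) -> None:
    start_position, job_name = capture_token
    print("*" * 10, f" BEGIN LOGS FOR {job_name} ", "*" * 10)
    with open(JUPYTER_JVM_LOG_FILE) as file_handler:
        file_handler.seek(start_position)
        print(file_handler.read())
    print("*" * 10, f" END LOGS FOR {job_name} ", "*" * 10)
```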
```python # Import the required modules diff --git a/api/py/ai/chronon/pyspark/constants.py b/api/py/ai/chronon/pyspark/constants.py index 7acd8b1ede..4c0e6a51a5 100644 --- a/api/py/ai/chronon/pyspark/constants.py +++ b/api/py/ai/chronon/pyspark/constants.py @@ -1,4 +1,5 @@ from __future__ import annotations + from typing import Optional # -------------------------------------------------------------------------- diff --git a/api/py/ai/chronon/pyspark/executables.py b/api/py/ai/chronon/pyspark/executables.py index dec1942260..90edc7ea5e 100644 --- a/api/py/ai/chronon/pyspark/executables.py +++ b/api/py/ai/chronon/pyspark/executables.py @@ -5,13 +5,13 @@ from datetime import datetime, timedelta from typing import Any, Generic, TypeVar, cast -from py4j.java_gateway import JVMView, JavaObject -from pyspark.sql import DataFrame, SparkSession - from ai.chronon.api.ttypes import ( GroupBy, Join, JoinPart, JoinSource, Query, Source, StagingQuery ) from ai.chronon.pyspark.constants import PARTITION_COLUMN_FORMAT +from py4j.java_gateway import JVMView, JavaObject +from pyspark.sql import DataFrame, SparkSession + from ai.chronon.repo.serializer import thrift_simple_json from ai.chronon.utils import ( __set_name as set_name, get_max_window_for_gb_in_days, output_table_name @@ -42,10 +42,10 @@ def __init__(self, obj: T, spark_session: SparkSession): self.platform: PlatformInterface = self.get_platform() self.java_spark_session: JavaObject = self.spark._jsparkSession self.default_start_date: str = ( - datetime.now() - timedelta(days=8) + datetime.now() - timedelta(days=8) ).strftime(PARTITION_COLUMN_FORMAT) self.default_end_date: str = ( - datetime.now() - timedelta(days=1) + datetime.now() - timedelta(days=1) ).strftime(PARTITION_COLUMN_FORMAT) @abstractmethod @@ -59,7 +59,7 @@ def get_platform(self) -> PlatformInterface: pass def _update_query_dates( - self, query: Query, start_date: str, end_date: str + self, query: Query, start_date: str, end_date: str ) -> Query: """ Update start and end dates of a query. @@ -78,7 +78,7 @@ def _update_query_dates( return query_copy def _update_source_dates( - self, source: Source, start_date: str, end_date: str + self, source: Source, start_date: str, end_date: str ) -> Source: """ Update start and end dates of a source. @@ -101,7 +101,7 @@ def _update_source_dates( return source_copy def _execute_underlying_join_sources( - self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int + self, group_bys: list[GroupBy], start_date: str, end_date: str, step_days: int ) -> None: """ Execute underlying join sources. @@ -132,8 +132,8 @@ def _execute_underlying_join_sources( group_by) shifted_start_date = ( - datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - - timedelta(days=max_window_for_gb_in_days) + datetime.strptime(start_date, PARTITION_COLUMN_FORMAT) - + timedelta(days=max_window_for_gb_in_days) ).strftime(PARTITION_COLUMN_FORMAT) for js in group_by_join_sources: @@ -235,7 +235,7 @@ class GroupByExecutable(PySparkExecutable[GroupBy], ABC): """Interface for executing GroupBy objects""" def _update_source_dates_for_group_by( - self, group_by: GroupBy, start_date: str, end_date: str + self, group_by: GroupBy, start_date: str, end_date: str ) -> GroupBy: """ Update start and end dates of sources in GroupBy. 
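The start-date shift performed in `_execute_underlying_join_sources` above is easiest to see with numbers: the underlying join source must be materialized far enough back that the GroupBy's widest aggregation window is fully covered. A worked sketch, assuming `PARTITION_COLUMN_FORMAT` is the `"%Y%m%d"` layout used by the date strings elsewhere in this module (the real constant lives in `ai.chronon.pyspark.constants`):

```python
# Worked example of the shifted start date computed before running an
# underlying join source.  PARTITION_COLUMN_FORMAT is assumed to be
# "%Y%m%d" for illustration only.
from datetime import datetime, timedelta

PARTITION_COLUMN_FORMAT = "%Y%m%d"  # assumed value

start_date = "20250110"
max_window_for_gb_in_days = 7  # widest window across the GroupBy's aggregations

shifted_start_date = (
    datetime.strptime(start_date, PARTITION_COLUMN_FORMAT)
    - timedelta(days=max_window_for_gb_in_days)
).strftime(PARTITION_COLUMN_FORMAT)

print(shifted_start_date)  # "20250103": seven extra days of upstream data
```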
@@ -258,11 +258,11 @@ def _update_source_dates_for_group_by( return group_by def run( - self, - start_date: str | None = None, - end_date: str | None = None, - step_days: int = 30, - skip_execution_of_underlying_join: bool = False + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_execution_of_underlying_join: bool = False ) -> DataFrame: """ Execute the GroupBy object. @@ -350,10 +350,10 @@ def run( raise e def analyze( - self, - start_date: str | None = None, - end_date: str | None = None, - enable_hitter_analysis: bool = False + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False ) -> None: """ Analyze the GroupBy object. @@ -414,7 +414,7 @@ class JoinExecutable(PySparkExecutable[Join], ABC): """Interface for executing Join objects""" def _update_source_dates_for_join_parts( - self, join_parts: list[JoinPart], start_date: str, end_date: str + self, join_parts: list[JoinPart], start_date: str, end_date: str ) -> list[JoinPart]: """ Update start and end dates of sources in JoinParts. @@ -438,13 +438,13 @@ def _update_source_dates_for_join_parts( return join_parts def run( - self, - start_date: str | None = None, - end_date: str | None = None, - step_days: int = 30, - skip_first_hole: bool = False, - sample_num_of_rows: int | None = None, - skip_execution_of_underlying_join: bool = False + self, + start_date: str | None = None, + end_date: str | None = None, + step_days: int = 30, + skip_first_hole: bool = False, + sample_num_of_rows: int | None = None, + skip_execution_of_underlying_join: bool = False ) -> DataFrame: """ Execute the Join object with Join-specific parameters. @@ -531,10 +531,10 @@ def run( raise e def analyze( - self, - start_date: str | None = None, - end_date: str | None = None, - enable_hitter_analysis: bool = False + self, + start_date: str | None = None, + end_date: str | None = None, + enable_hitter_analysis: bool = False ) -> None: """ Analyze the Join object. @@ -702,11 +702,11 @@ def drop_table_if_exists(self, table_name: str) -> None: _ = self.spark.sql(f"DROP TABLE IF EXISTS {table_name}") def set_metadata( - self, - obj: GroupBy | Join | StagingQuery, - mod_prefix: str, - name_prefix: str | None = None, - output_namespace: str | None = None + self, + obj: GroupBy | Join | StagingQuery, + mod_prefix: str, + name_prefix: str | None = None, + output_namespace: str | None = None ) -> T: """ Set the metadata for the object. 
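The Join-specific `run()` signature above carries a few knobs beyond the GroupBy path, so a short usage sketch may help. `my_join` and `spark` are assumed to exist already (a fully defined `ai.chronon.api.ttypes.Join` and an active `SparkSession` with the Chronon jars attached); the parameter values are illustrative only.

```python
# Usage sketch for JoinExecutable.run() via the Jupyter implementation.
# Assumes `my_join` is a complete Join thrift object and `spark` is an
# active SparkSession with the Chronon jars on its classpath.
from ai.chronon.pyspark.jupyter_platform import JupyterJoin

executable = JupyterJoin(my_join, spark)

result_df = executable.run(
    start_date="20250101",
    end_date="20250107",
    step_days=7,                             # backfill in 7-day chunks
    skip_first_hole=False,
    sample_num_of_rows=10_000,               # sample rows to keep the prototype cheap
    skip_execution_of_underlying_join=True,  # reuse already-materialized join sources
)
result_df.show()
```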
diff --git a/api/py/ai/chronon/pyspark/jupyter_platform.py b/api/py/ai/chronon/pyspark/jupyter_platform.py index e550513b12..a9ec32ad43 100644 --- a/api/py/ai/chronon/pyspark/jupyter_platform.py +++ b/api/py/ai/chronon/pyspark/jupyter_platform.py @@ -1,21 +1,17 @@ from __future__ import annotations -from typing import cast - -from pyspark.sql import SparkSession -from typing_extensions import override - from ai.chronon.api.ttypes import GroupBy, Join from ai.chronon.pyspark.constants import ( - # NOTEBOOKS_JVM_LOG_FILE, --> Commented by ONUR - NOTEBOOKS_OUTPUT_NAMESPACE, - NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, + JUPYTER_OUTPUT_NAMESPACE, + JUPYTER_ROOT_DIR_FOR_IMPORTED_FEATURES, ) from ai.chronon.pyspark.executables import ( GroupByExecutable, JoinExecutable, PlatformInterface, ) +from pyspark.sql import SparkSession +from typing_extensions import override class JupyterPlatform(PlatformInterface): @@ -82,8 +78,8 @@ def __init__(self, group_by: GroupBy, spark_session: SparkSession): self.obj: GroupBy = self.platform.set_metadata( obj=self.obj, - mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, - output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE + mod_prefix=JUPYTER_ROOT_DIR_FOR_IMPORTED_FEATURES, + output_namespace=JUPYTER_OUTPUT_NAMESPACE ) @override @@ -112,8 +108,8 @@ def __init__(self, join: Join, spark_session: SparkSession): self.obj: Join = self.platform.set_metadata( obj=self.obj, - mod_prefix=NOTEBOOKS_ROOT_DIR_FOR_IMPORTED_FEATURES, - output_namespace=NOTEBOOKS_OUTPUT_NAMESPACE + mod_prefix=JUPYTER_ROOT_DIR_FOR_IMPORTED_FEATURES, + output_namespace=JUPYTER_OUTPUT_NAMESPACE ) @override From b24cc1745e6f90619ff246562dd61072068ee82e Mon Sep 17 00:00:00 2001 From: Akshay Thorat Date: Sat, 1 Nov 2025 22:32:46 -0700 Subject: [PATCH 5/5] Fix CI errors: Add missing TableUtils import and fix Python lint issues - Added missing import for ai.chronon.spark.catalog.TableUtils in PySparkUtils.scala - Fixed missing newline at end of utils.py file (W292) These changes fix the following CI failures: - Scala 11/12/13 compilation errors (missing TableUtils type) - Python lint error (no newline at end of file) --- api/py/ai/chronon/utils.py | 4 +++- spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index d55acb67b4..bbe2e8ec09 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -23,6 +23,7 @@ from collections.abc import Iterable from dataclasses import dataclass, fields from enum import Enum +from math import ceil from typing import Dict, List, Optional, Union, cast import ai.chronon.api.ttypes as api @@ -577,6 +578,7 @@ def get_config_path(join_name: str) -> str: team_name, config_name = join_name.split(".", 1) return f"production/joins/{team_name}/{config_name}" + def get_max_window_for_gb_in_days(group_by: api.GroupBy) -> int: result: int = 1 if group_by.aggregations: @@ -605,4 +607,4 @@ def get_max_window_for_gb_in_days(group_by: api.GroupBy) -> int: f"Unsupported time unit {window.timeUnit}. " + "Please add logic above to handle the newly introduced time unit." 
) - return result \ No newline at end of file + return result diff --git a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala index 25ca48167e..7af31dd35d 100644 --- a/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala +++ b/spark/src/main/scala/ai/chronon/spark/PySparkUtils.scala @@ -4,6 +4,7 @@ import ai.chronon.aggregator.windowing.{FiveMinuteResolution, Resolution} import ai.chronon.api import ai.chronon.api.Extensions.MetadataOps import ai.chronon.api.ThriftJsonCodec +import ai.chronon.spark.catalog.TableUtils import org.apache.spark.sql.DataFrame import org.slf4j.LoggerFactory
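The `get_max_window_for_gb_in_days` helper shown in the utils.py hunk above walks a GroupBy's aggregation windows, and the `from math import ceil` import it now pulls in suggests that sub-day windows are rounded up to whole days so the shifted backfill range always covers the full window. A standalone sketch of that rounding, using illustrative string time units rather than the real `api.TimeUnit` values:

```python
# Standalone illustration of rounding an aggregation window up to whole
# days, the arithmetic the `ceil` import in utils.py points at.  The
# string time units here are stand-ins for ai.chronon.api TimeUnit values.
from math import ceil


def window_length_in_days(length: int, time_unit: str) -> int:
    if time_unit == "DAYS":
        return length
    if time_unit == "HOURS":
        # Round up so e.g. a 36-hour window reserves 2 days, not 1.
        return ceil(length / 24)
    raise ValueError(f"Unsupported time unit {time_unit}")


print(window_length_in_days(36, "HOURS"))  # 2
print(window_length_in_days(7, "DAYS"))    # 7
```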