diff --git a/daft/daft.pyi b/daft/daft.pyi index dd98bc0e3c..2882d003a8 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1621,6 +1621,7 @@ class LogicalPlanBuilder: ) -> LogicalPlanBuilder: ... @staticmethod def table_scan(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... + def with_planning_config(self, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def select(self, to_select: list[PyExpr]) -> LogicalPlanBuilder: ... def with_columns(self, columns: list[PyExpr]) -> LogicalPlanBuilder: ... def exclude(self, to_exclude: list[str]) -> LogicalPlanBuilder: ... diff --git a/daft/logical/builder.py b/daft/logical/builder.py index db40e0a461..44a33ba9da 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -1,8 +1,10 @@ from __future__ import annotations +import functools import pathlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable +from daft.context import get_context from daft.daft import ( CountMode, FileFormat, @@ -27,6 +29,25 @@ ) +def _apply_daft_planning_config_to_initializer(classmethod_func: Callable[..., LogicalPlanBuilder]): + """Decorator to be applied to any @classmethod instantiation method on LogicalPlanBuilder + + This decorator ensures that the current DaftPlanningConfig is applied to the instantiated LogicalPlanBuilder + """ + + @functools.wraps(classmethod_func) + def wrapper(cls: type[LogicalPlanBuilder], *args, **kwargs): + instantiated_logical_plan_builder = classmethod_func(cls, *args, **kwargs) + + # Parametrize the builder with the current DaftPlanningConfig + inner = instantiated_logical_plan_builder._builder + inner = inner.with_planning_config(get_context().daft_planning_config) + + return cls(inner) + + return wrapper + + class LogicalPlanBuilder: """ A logical plan builder for the Daft DataFrame. @@ -87,6 +108,7 @@ def optimize(self) -> LogicalPlanBuilder: return LogicalPlanBuilder(builder) @classmethod + @_apply_daft_planning_config_to_initializer def from_in_memory_scan( cls, partition: PartitionCacheEntry, @@ -106,6 +128,7 @@ def from_in_memory_scan( return cls(builder) @classmethod + @_apply_daft_planning_config_to_initializer def from_tabular_scan( cls, *, diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index ab8f829d93..381abdf4cf 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -29,6 +29,7 @@ use daft_scan::{file_format::FileFormat, PhysicalScanInfo, Pushdowns, ScanOperat use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, crate::source_info::InMemoryInfo, + common_daft_config::PyDaftPlanningConfig, daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, daft_scan::python::pylib::ScanOperatorHandle, @@ -551,6 +552,13 @@ impl PyLogicalPlanBuilder { Ok(LogicalPlanBuilder::table_scan(scan_operator.into(), None)?.into()) } + pub fn with_planning_config( + &self, + daft_planning_config: PyDaftPlanningConfig, + ) -> PyResult { + Ok(self.builder.with_config(daft_planning_config.config).into()) + } + pub fn select(&self, to_select: Vec) -> PyResult { Ok(self.builder.select(pyexprs_to_exprs(to_select))?.into()) }