diff --git a/daft/daft.pyi b/daft/daft.pyi index 7f6e793d0a..0c4e043bb5 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1625,6 +1625,7 @@ class LogicalPlanBuilder: ) -> LogicalPlanBuilder: ... @staticmethod def table_scan(scan_operator: ScanOperatorHandle) -> LogicalPlanBuilder: ... + def with_planning_config(self, daft_planning_config: PyDaftPlanningConfig) -> LogicalPlanBuilder: ... def select(self, to_select: list[PyExpr]) -> LogicalPlanBuilder: ... def with_columns(self, columns: list[PyExpr]) -> LogicalPlanBuilder: ... def exclude(self, to_exclude: list[str]) -> LogicalPlanBuilder: ... diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 89f19458fd..b2717df2f6 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -1,8 +1,10 @@ from __future__ import annotations +import functools import pathlib -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable +from daft.context import get_context from daft.daft import ( CountMode, FileFormat, @@ -26,6 +28,25 @@ ) +def _apply_daft_planning_config_to_initializer(classmethod_func: Callable[..., LogicalPlanBuilder]): + """Decorator to be applied to any @classmethod instantiation method on LogicalPlanBuilder + + This decorator ensures that the current DaftPlanningConfig is applied to the instantiated LogicalPlanBuilder + """ + + @functools.wraps(classmethod_func) + def wrapper(cls: type[LogicalPlanBuilder], *args, **kwargs): + instantiated_logical_plan_builder = classmethod_func(cls, *args, **kwargs) + + # Parametrize the builder with the current DaftPlanningConfig + inner = instantiated_logical_plan_builder._builder + inner = inner.with_planning_config(get_context().daft_planning_config) + + return cls(inner) + + return wrapper + + class LogicalPlanBuilder: """ A logical plan builder for the Daft DataFrame. @@ -91,6 +112,7 @@ def optimize(self) -> LogicalPlanBuilder: return LogicalPlanBuilder(builder) @classmethod + @_apply_daft_planning_config_to_initializer def from_in_memory_scan( cls, partition: PartitionCacheEntry, @@ -110,6 +132,7 @@ def from_in_memory_scan( return cls(builder) @classmethod + @_apply_daft_planning_config_to_initializer def from_tabular_scan( cls, *, diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index ab96622b2d..9a5719c652 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -30,6 +30,7 @@ use daft_scan::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, crate::source_info::InMemoryInfo, + common_daft_config::PyDaftPlanningConfig, daft_core::python::schema::PySchema, daft_dsl::python::PyExpr, daft_scan::python::pylib::ScanOperatorHandle, @@ -548,6 +549,13 @@ impl PyLogicalPlanBuilder { Ok(LogicalPlanBuilder::table_scan(scan_operator.into(), None)?.into()) } + pub fn with_planning_config( + &self, + daft_planning_config: PyDaftPlanningConfig, + ) -> PyResult { + Ok(self.builder.with_config(daft_planning_config.config).into()) + } + pub fn select(&self, to_select: Vec) -> PyResult { Ok(self.builder.select(pyexprs_to_exprs(to_select))?.into()) }