Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,7 @@ class DistributedPhysicalPlan:
def repr_mermaid(self, options: MermaidOptions) -> str: ...

class DistributedPhysicalPlanRunner:
def __init__(self, on_actor: bool) -> None: ...
def __init__(self) -> None: ...
def run_plan(
self, plan: DistributedPhysicalPlan, psets: dict[str, list[RayPartitionRef]]
) -> AsyncIterator[RayPartitionRef]: ...
Expand Down
122 changes: 28 additions & 94 deletions daft/runners/flotilla.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,14 @@ def try_autoscale(bundles: list[dict[str, int]]) -> None:
)


class FlotillaRunnerCore:
"""Core functionality for running distributed physical plans with Flotilla.

This class contains the business logic for managing plans and their execution,
separate from the Ray actor wrapper.
"""

def __init__(self, on_actor: bool = False) -> None:
@ray.remote(
num_cpus=0,
)
class RemoteFlotillaRunner:
def __init__(self) -> None:
self.curr_plans: dict[str, DistributedPhysicalPlan] = {}
self.curr_result_gens: dict[str, AsyncIterator[RayPartitionRef]] = {}
self.plan_runner = DistributedPhysicalPlanRunner(on_actor)
self.plan_runner = DistributedPhysicalPlanRunner()

def run_plan(
self,
Expand Down Expand Up @@ -235,58 +232,7 @@ async def get_next_partition(self, plan_id: str) -> RayMaterializedResult | None
return materialized_result


class LocalFlotillaRunner:
"""Local wrapper around FlotillaPlanRunnerCore.

This wrapper provides the same interface as FlotillaPlanRunner but without
Ray actor overhead, useful for local testing or when distributed execution
is not needed.
"""

def __init__(self) -> None:
self.loop = asyncio.new_event_loop()
self.core = self.loop.run_until_complete(self._make_runner())

async def _make_runner(self) -> FlotillaRunnerCore:
return FlotillaRunnerCore(on_actor=False)

def run_plan(
self,
plan: DistributedPhysicalPlan,
partition_sets: dict[str, PartitionSet[ray.ObjectRef]],
) -> None:
self.core.run_plan(plan, partition_sets)

def get_next_partition(self, plan_id: str) -> RayMaterializedResult | None:
"""Synchronous version of get_next_partition that internally uses asyncio."""
return self.loop.run_until_complete(self.core.get_next_partition(plan_id))


@ray.remote(
num_cpus=0,
)
class RemoteFlotillaRunner:
"""Ray actor wrapper around FlotillaPlanRunnerCore.

This actor provides the distributed interface for running plans,
while delegating the actual work to the core class.
"""

def __init__(self) -> None:
self.core = FlotillaRunnerCore(on_actor=True)

def run_plan(
self,
plan: DistributedPhysicalPlan,
partition_sets: dict[str, PartitionSet[ray.ObjectRef]],
) -> None:
self.core.run_plan(plan, partition_sets)

async def get_next_partition(self, plan_id: str) -> RayMaterializedResult | None:
return await self.core.get_next_partition(plan_id)


FLOTILLA_RUNER_NAMESPACE = "daft"
FLOTILLA_RUNNER_NAMESPACE = "daft"
FLOTILLA_RUNNER_NAME = "flotilla-plan-runner"


Expand All @@ -304,43 +250,31 @@ def get_head_node_id() -> str | None:
class FlotillaRunner:
"""FlotillaRunner is a wrapper around the named RemoteFlotillaRunner Ray actor: it submits plans to the actor and streams back the materialized results."""

def __init__(self, use_actor: bool = False) -> None:
self.use_actor = use_actor
if self.use_actor:
head_node_id = get_head_node_id()
self.runner = RemoteFlotillaRunner.options( # type: ignore
name=FLOTILLA_RUNNER_NAME,
namespace=FLOTILLA_RUNER_NAMESPACE,
get_if_exists=True,
scheduling_strategy=(
ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=head_node_id,
soft=False,
)
if head_node_id is not None
else "DEFAULT"
),
).remote()
else:
self.runner = LocalFlotillaRunner()
def __init__(self) -> None:
head_node_id = get_head_node_id()
self.runner = RemoteFlotillaRunner.options( # type: ignore
name=FLOTILLA_RUNNER_NAME,
namespace=FLOTILLA_RUNNER_NAMESPACE,
get_if_exists=True,
scheduling_strategy=(
ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=head_node_id,
soft=False,
)
if head_node_id is not None
else "DEFAULT"
),
).remote()

def stream_plan(
self,
plan: DistributedPhysicalPlan,
partition_sets: dict[str, PartitionSet[ray.ObjectRef]],
) -> Iterator[RayMaterializedResult]:
plan_id = plan.id()
if self.use_actor:
ray.get(self.runner.run_plan.remote(plan, partition_sets))
while True:
materialized_result = ray.get(self.runner.get_next_partition.remote(plan_id))
if materialized_result is None:
break
yield materialized_result
else:
self.runner.run_plan(plan, partition_sets)
while True:
materialized_result = self.runner.get_next_partition(plan_id)
if materialized_result is None:
break
yield materialized_result
ray.get(self.runner.run_plan.remote(plan, partition_sets))
while True:
materialized_result = ray.get(self.runner.get_next_partition.remote(plan_id))
if materialized_result is None:
break
yield materialized_result
2 changes: 1 addition & 1 deletion daft/runners/ray_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,7 @@ def run_iter(
yield from self._execute_plan(builder, daft_execution_config, results_buffer_size)
else:
if self.flotilla_plan_runner is None:
self.flotilla_plan_runner = FlotillaRunner(self.ray_client_mode)
self.flotilla_plan_runner = FlotillaRunner()
yield from self.flotilla_plan_runner.stream_plan(
distributed_plan, self._part_set_cache.get_all_partition_sets()
)
Expand Down
9 changes: 3 additions & 6 deletions src/daft-distributed/src/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,17 +98,15 @@ impl_bincode_py_state_serialization!(PyDistributedPhysicalPlan);
#[pyclass(module = "daft.daft", name = "DistributedPhysicalPlanRunner", frozen)]
struct PyDistributedPhysicalPlanRunner {
runner: Arc<PlanRunner<RaySwordfishWorker>>,
on_ray_actor: bool,
}

#[pymethods]
impl PyDistributedPhysicalPlanRunner {
#[new]
fn new(py: Python, on_ray_actor: bool) -> PyResult<Self> {
fn new(py: Python) -> PyResult<Self> {
let worker_manager = RayWorkerManager::try_new(py)?;
Ok(Self {
runner: Arc::new(PlanRunner::new(Arc::new(worker_manager))),
on_ray_actor,
})
}

Expand All @@ -130,9 +128,8 @@ impl PyDistributedPhysicalPlanRunner {
})
.collect();

let mut subscribers: Vec<Box<dyn StatisticsSubscriber>> = vec![Box::new(
FlotillaProgressBar::try_new(py, self.on_ray_actor)?,
)];
let mut subscribers: Vec<Box<dyn StatisticsSubscriber>> =
vec![Box::new(FlotillaProgressBar::try_new(py)?)];

tracing::info!("Checking DAFT_DASHBOARD_URL environment variable");
match std::env::var("DAFT_DASHBOARD_URL") {
Expand Down
6 changes: 2 additions & 4 deletions src/daft-distributed/src/python/progress_bar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ pub(crate) struct FlotillaProgressBar {
}

impl FlotillaProgressBar {
pub fn try_new(py: Python, use_ray_tqdm: bool) -> PyResult<Self> {
pub fn try_new(py: Python) -> PyResult<Self> {
let progress_bar_module = py.import(pyo3::intern!(py, "daft.runners.progress_bar"))?;
let progress_bar_class = progress_bar_module.getattr(pyo3::intern!(py, "ProgressBar"))?;
let progress_bar = progress_bar_class
.call1((use_ray_tqdm,))?
.extract::<PyObject>()?;
let progress_bar = progress_bar_class.call1((true,))?.extract::<PyObject>()?;
Ok(Self {
progress_bar_pyobject: progress_bar,
})
Expand Down
24 changes: 24 additions & 0 deletions tests/dataframe/test_async_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations


def test_dataframe_running_in_async_context():
    """Materialize a Daft query from inside a coroutine driven by ``asyncio.run``.

    The query applies an async UDF and is collected with the synchronous
    ``to_pydict()`` call while an event loop is already running on this thread.
    """
    import asyncio

    import daft

    expected = {"id": [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]}

    async def main():
        @daft.func
        async def add_one(x: int) -> int:
            await asyncio.sleep(0.1)
            return x + 1

        # Keep the even ids, bump each one through the async UDF, then take ten rows.
        even_ids = daft.range(100, partitions=10).where(daft.col("id") % 2 == 0)
        incremented = even_ids.with_column("id", add_one(daft.col("id")))
        result = incremented.limit(10).to_pydict()
        assert result == expected

    asyncio.run(main())
Loading